From 5d62b56f7e9b79e1fb5d00d50512da7e2d71d481 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Tue, 5 Aug 2025 18:30:19 -0700 Subject: [PATCH] [router] complete router oai spec (#8828) --- sgl-router/benches/request_processing.rs | 169 ++++--- sgl-router/src/openai_api_types.rs | 203 +++++++- sgl-router/src/routers/pd_types.rs | 106 ++--- sgl-router/src/routers/request_adapter.rs | 552 +++++++++++++++------- sgl-router/tests/benchmark_integration.rs | 189 +++++--- 5 files changed, 855 insertions(+), 364 deletions(-) diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index db5cdc901..a997b8dfd 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -8,12 +8,116 @@ use sglang_router_rs::openai_api_types::{ }; use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest}; +/// Create a default GenerateRequest for benchmarks with minimal fields set +fn default_generate_request() -> GenerateRequest { + GenerateRequest { + text: None, + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + // SGLang Extensions + lora_path: None, + session_params: None, + return_hidden_states: false, + rid: None, + } +} + +/// Create a default ChatCompletionRequest for benchmarks with minimal fields set +fn default_chat_completion_request() -> ChatCompletionRequest { + ChatCompletionRequest { + model: String::new(), + messages: vec![], + max_tokens: None, + max_completion_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + logprobs: false, + top_logprobs: None, + user: None, + response_format: None, + seed: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + function_call: None, + functions: None, + // SGLang Extensions + top_k: None, + min_p: None, + min_tokens: None, + repetition_penalty: None, + regex: None, + ebnf: None, + stop_token_ids: None, + no_stop_trim: false, + ignore_eos: false, + continue_final_message: false, + skip_special_tokens: true, + // SGLang Extensions + lora_path: None, + session_params: None, + separate_reasoning: true, + stream_reasoning: true, + return_hidden_states: false, + } +} + +/// Create a default CompletionRequest for benchmarks with minimal fields set +fn default_completion_request() -> CompletionRequest { + CompletionRequest { + model: String::new(), + prompt: StringOrArray::String(String::new()), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + // SGLang Extensions + top_k: None, + min_p: None, + min_tokens: None, + repetition_penalty: None, + regex: None, + ebnf: None, + json_schema: None, + stop_token_ids: None, + no_stop_trim: false, + ignore_eos: false, + skip_special_tokens: true, + // SGLang Extensions + lora_path: None, + session_params: None, + return_hidden_states: false, + other: serde_json::Map::new(), + } +} + // Sample request data for benchmarks fn create_sample_generate_request() -> GenerateRequest { GenerateRequest { text: Some("Write a story about artificial intelligence".to_string()), - input_ids: None, - prompt: None, parameters: Some(GenerateParameters { max_new_tokens: Some(100), temperature: Some(0.8), @@ -31,8 +135,7 @@ fn create_sample_generate_request() -> GenerateRequest { repetition_penalty: Some(1.0), ..Default::default() }), - stream: false, - return_logprob: false, + ..default_generate_request() } } @@ -58,22 +161,10 @@ fn create_sample_chat_completion_request() -> ChatCompletionRequest { temperature: Some(0.7), top_p: Some(1.0), n: Some(1), - stream: false, - stream_options: None, - stop: None, presence_penalty: Some(0.0), frequency_penalty: Some(0.0), - logit_bias: None, - logprobs: false, - top_logprobs: None, - user: None, - response_format: None, - seed: None, - tools: None, - tool_choice: None, parallel_tool_calls: Some(true), - function_call: None, - functions: None, + ..default_chat_completion_request() } } @@ -81,23 +172,14 @@ fn create_sample_completion_request() -> CompletionRequest { CompletionRequest { model: "text-davinci-003".to_string(), prompt: StringOrArray::String("Complete this sentence: The future of AI is".to_string()), - suffix: None, max_tokens: Some(50), temperature: Some(0.8), top_p: Some(1.0), n: Some(1), - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, presence_penalty: Some(0.0), frequency_penalty: Some(0.0), best_of: Some(1), - logit_bias: None, - user: None, - seed: None, - other: serde_json::Map::new(), + ..default_completion_request() } } @@ -121,6 +203,7 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest { name: None, tool_calls: None, function_call: None, + reasoning_content: None, }); } @@ -132,22 +215,13 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest { temperature: Some(0.7), top_p: Some(0.95), n: Some(1), - stream: false, - stream_options: None, - stop: None, presence_penalty: Some(0.1), frequency_penalty: Some(0.1), - logit_bias: None, - logprobs: false, top_logprobs: Some(5), user: Some("benchmark_user".to_string()), - response_format: None, seed: Some(42), - tools: None, - tool_choice: None, parallel_tool_calls: Some(true), - function_call: None, - functions: None, + ..default_chat_completion_request() } } @@ -331,32 +405,17 @@ fn bench_throughput_by_size(c: &mut Criterion) { // Create requests of different sizes let small_generate = GenerateRequest { text: Some("Hi".to_string()), - input_ids: None, - prompt: None, - parameters: None, - sampling_params: None, - stream: false, - return_logprob: false, + ..default_generate_request() }; let medium_generate = GenerateRequest { text: Some("Write a medium length story about AI".repeat(10)), - input_ids: None, - prompt: None, - parameters: None, - sampling_params: None, - stream: false, - return_logprob: false, + ..default_generate_request() }; let large_generate = GenerateRequest { text: Some("Write a very long and detailed story about artificial intelligence and its impact on society".repeat(100)), - input_ids: None, - prompt: None, - parameters: None, - sampling_params: None, - stream: false, - return_logprob: false, + ..default_generate_request() }; for (name, req) in [ diff --git a/sgl-router/src/openai_api_types.rs b/sgl-router/src/openai_api_types.rs index d57e61767..4a0fb0ee0 100644 --- a/sgl-router/src/openai_api_types.rs +++ b/sgl-router/src/openai_api_types.rs @@ -6,6 +6,21 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; +/// Helper function for serde default value +fn default_true() -> bool { + true +} + +// ============= SGLang-Specific Types ============= + +/// LoRA adapter path - can be single path or batch of paths +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum LoRAPath { + Single(Option), + Batch(Vec>), +} + /// Common trait for all generation requests pub trait GenerationRequest: Send + Sync { /// Check if the request is for streaming @@ -92,6 +107,64 @@ pub struct CompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] pub seed: Option, + // ============= SGLang Extensions ============= + /// Top-k sampling parameter (-1 to disable) + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + + /// Min-p nucleus sampling parameter + #[serde(skip_serializing_if = "Option::is_none")] + pub min_p: Option, + + /// Minimum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + + /// Repetition penalty for reducing repetitive text + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + + /// Regex constraint for output generation + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + + /// EBNF grammar constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + + /// JSON schema constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, + + /// Specific token IDs to use as stop conditions + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + + /// Skip trimming stop tokens from output + #[serde(default)] + pub no_stop_trim: bool, + + /// Ignore end-of-sequence tokens during generation + #[serde(default)] + pub ignore_eos: bool, + + /// Skip special tokens during detokenization + #[serde(default = "default_true")] + pub skip_special_tokens: bool, + + // ============= SGLang Extensions ============= + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Return model hidden states + #[serde(default)] + pub return_hidden_states: bool, + /// Additional fields including bootstrap info for PD routing #[serde(flatten)] pub other: serde_json::Map, @@ -166,7 +239,7 @@ pub struct ChatCompletionRequest { /// Modify the likelihood of specified tokens appearing in the completion #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, + pub logit_bias: Option>, /// A unique identifier representing your end-user #[serde(skip_serializing_if = "Option::is_none")] @@ -207,6 +280,72 @@ pub struct ChatCompletionRequest { /// Deprecated: use tool_choice instead #[serde(skip_serializing_if = "Option::is_none")] pub function_call: Option, + + // ============= SGLang Extensions ============= + /// Top-k sampling parameter (-1 to disable) + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + + /// Min-p nucleus sampling parameter + #[serde(skip_serializing_if = "Option::is_none")] + pub min_p: Option, + + /// Minimum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + + /// Repetition penalty for reducing repetitive text + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + + /// Regex constraint for output generation + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + + /// EBNF grammar constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + + /// Specific token IDs to use as stop conditions + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + + /// Skip trimming stop tokens from output + #[serde(default)] + pub no_stop_trim: bool, + + /// Ignore end-of-sequence tokens during generation + #[serde(default)] + pub ignore_eos: bool, + + /// Continue generating from final assistant message + #[serde(default)] + pub continue_final_message: bool, + + /// Skip special tokens during detokenization + #[serde(default = "default_true")] + pub skip_special_tokens: bool, + + // ============= SGLang Extensions ============= + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Separate reasoning content from final answer (O1-style models) + #[serde(default = "default_true")] + pub separate_reasoning: bool, + + /// Stream reasoning tokens during generation + #[serde(default = "default_true")] + pub stream_reasoning: bool, + + /// Return model hidden states + #[serde(default)] + pub return_hidden_states: bool, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -234,6 +373,9 @@ pub enum ChatMessage { tool_calls: Option>, #[serde(skip_serializing_if = "Option::is_none")] function_call: Option, + /// Reasoning content for O1-style models (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_content: Option, }, Tool { role: String, // "tool" @@ -378,7 +520,20 @@ impl GenerationRequest for ChatCompletionRequest { Some(texts.join(" ")) } }, - ChatMessage::Assistant { content, .. } => content.clone(), + ChatMessage::Assistant { + content, + reasoning_content, + .. + } => { + // Combine content and reasoning content for routing decisions + let main_content = content.clone().unwrap_or_default(); + let reasoning = reasoning_content.clone().unwrap_or_default(); + if main_content.is_empty() && reasoning.is_empty() { + None + } else { + Some(format!("{} {}", main_content, reasoning).trim().to_string()) + } + } ChatMessage::Tool { content, .. } => Some(content.clone()), ChatMessage::Function { content, .. } => Some(content.clone()), }) @@ -418,6 +573,23 @@ pub struct GenerateRequest { /// Whether to return logprobs #[serde(default)] pub return_logprob: bool, + + // ============= SGLang Extensions ============= + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Return model hidden states + #[serde(default)] + pub return_hidden_states: bool, + + /// Request ID for tracking + #[serde(skip_serializing_if = "Option::is_none")] + pub rid: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -485,6 +657,18 @@ pub struct SamplingParams { pub skip_special_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] pub json_schema: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub no_stop_trim: Option, } impl GenerationRequest for GenerateRequest { @@ -561,6 +745,12 @@ pub struct CompletionChoice { #[serde(skip_serializing_if = "Option::is_none")] pub logprobs: Option, pub finish_reason: Option, // "stop", "length", "content_filter", etc. + /// Information about which stop condition was matched + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, // Can be string or integer + /// Hidden states from the model (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + pub hidden_states: Option>, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -591,6 +781,12 @@ pub struct ChatChoice { #[serde(skip_serializing_if = "Option::is_none")] pub logprobs: Option, pub finish_reason: Option, // "stop", "length", "tool_calls", "content_filter", "function_call" + /// Information about which stop condition was matched + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, // Can be string or integer + /// Hidden states from the model (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + pub hidden_states: Option>, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -681,6 +877,9 @@ pub struct ChatMessageDelta { pub tool_calls: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub function_call: Option, + /// Reasoning content delta for O1-style models (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/sgl-router/src/routers/pd_types.rs b/sgl-router/src/routers/pd_types.rs index ce13977d6..34dabdd26 100644 --- a/sgl-router/src/routers/pd_types.rs +++ b/sgl-router/src/routers/pd_types.rs @@ -278,11 +278,11 @@ mod bootstrap_tests { use crate::core::BasicWorker; use crate::openai_api_types::StringOrArray; - #[test] - fn test_completion_batch_size_with_array_prompt() { - let req = CompletionRequest { - model: "test".to_string(), - prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), + /// Create a default CompletionRequest for testing with minimal fields set + fn default_completion_request() -> CompletionRequest { + CompletionRequest { + model: String::new(), + prompt: StringOrArray::String(String::new()), n: None, other: serde_json::Map::new(), suffix: None, @@ -300,6 +300,31 @@ mod bootstrap_tests { logit_bias: None, user: None, seed: None, + // SGLang Extensions + top_k: None, + min_p: None, + min_tokens: None, + repetition_penalty: None, + regex: None, + ebnf: None, + json_schema: None, + stop_token_ids: None, + no_stop_trim: false, + ignore_eos: false, + skip_special_tokens: true, + // SGLang Extensions + lora_path: None, + session_params: None, + return_hidden_states: false, + } + } + + #[test] + fn test_completion_batch_size_with_array_prompt() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), + ..default_completion_request() }; // Should return batch size for array prompt @@ -311,23 +336,7 @@ mod bootstrap_tests { let req = CompletionRequest { model: "test".to_string(), prompt: StringOrArray::String("single prompt".to_string()), - n: None, - other: serde_json::Map::new(), - suffix: None, - max_tokens: None, - temperature: None, - top_p: None, - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, - presence_penalty: None, - frequency_penalty: None, - best_of: None, - logit_bias: None, - user: None, - seed: None, + ..default_completion_request() }; // Should return None for single prompt @@ -340,22 +349,7 @@ mod bootstrap_tests { model: "test".to_string(), prompt: StringOrArray::String("single prompt".to_string()), n: Some(3), - other: serde_json::Map::new(), - suffix: None, - max_tokens: None, - temperature: None, - top_p: None, - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, - presence_penalty: None, - frequency_penalty: None, - best_of: None, - logit_bias: None, - user: None, - seed: None, + ..default_completion_request() }; // Should return None for single string prompt, even with n > 1 @@ -368,23 +362,7 @@ mod bootstrap_tests { let mut req = CompletionRequest { model: "test".to_string(), prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), - n: None, - other: serde_json::Map::new(), - suffix: None, - max_tokens: None, - temperature: None, - top_p: None, - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, - presence_penalty: None, - frequency_penalty: None, - best_of: None, - logit_bias: None, - user: None, - seed: None, + ..default_completion_request() }; // Set bootstrap info - should always use single values @@ -418,23 +396,7 @@ mod bootstrap_tests { let mut req = CompletionRequest { model: "test".to_string(), prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), - n: None, - other: serde_json::Map::new(), - suffix: None, - max_tokens: None, - temperature: None, - top_p: None, - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, - presence_penalty: None, - frequency_penalty: None, - best_of: None, - logit_bias: None, - user: None, - seed: None, + ..default_completion_request() }; // Set bootstrap info with arrays diff --git a/sgl-router/src/routers/request_adapter.rs b/sgl-router/src/routers/request_adapter.rs index f29bcecc9..809244793 100644 --- a/sgl-router/src/routers/request_adapter.rs +++ b/sgl-router/src/routers/request_adapter.rs @@ -176,6 +176,33 @@ impl ToPdRequest for CompletionRequest { self.stream => "stream" ); + // Add SGLang extension fields + insert_if_some!(other, + // SGLang Extensions - Priority 1 + self.top_k => "top_k", + self.min_p => "min_p", + self.min_tokens => "min_tokens", + self.repetition_penalty => "repetition_penalty", + self.regex => "regex", + self.ebnf => "ebnf", + self.stop_token_ids => "stop_token_ids", + // SGLang Extensions - Priority 2 + self.lora_path => "lora_path", + self.session_params => "session_params" + ); + + // SGLang boolean extensions (CompletionRequest has these as bool, not Option) + other.insert("no_stop_trim".to_string(), self.no_stop_trim.into()); + other.insert("ignore_eos".to_string(), self.ignore_eos.into()); + other.insert( + "skip_special_tokens".to_string(), + self.skip_special_tokens.into(), + ); + other.insert( + "return_hidden_states".to_string(), + self.return_hidden_states.into(), + ); + GenerateReqInput { text, input_ids: None, @@ -226,14 +253,46 @@ impl ToPdRequest for ChatCompletionRequest { self.tool_choice => "tool_choice", self.parallel_tool_calls => "parallel_tool_calls", self.functions => "functions", - self.function_call => "function_call" + self.function_call => "function_call", + // SGLang Extensions - Priority 1 + self.top_k => "top_k", + self.min_p => "min_p", + self.min_tokens => "min_tokens", + self.repetition_penalty => "repetition_penalty", + self.regex => "regex", + self.ebnf => "ebnf", + self.stop_token_ids => "stop_token_ids", + // SGLang Extensions - Priority 2 + self.lora_path => "lora_path", + self.session_params => "session_params" ); - // Handle boolean logprobs flag + // Handle boolean flags if self.logprobs { other.insert("logprobs".to_string(), true.into()); } + // SGLang boolean extensions (ChatCompletionRequest has these as bool, not Option) + other.insert("no_stop_trim".to_string(), self.no_stop_trim.into()); + other.insert("ignore_eos".to_string(), self.ignore_eos.into()); + other.insert( + "continue_final_message".to_string(), + self.continue_final_message.into(), + ); + other.insert( + "skip_special_tokens".to_string(), + self.skip_special_tokens.into(), + ); + other.insert( + "separate_reasoning".to_string(), + self.separate_reasoning.into(), + ); + other.insert("stream_reasoning".to_string(), self.stream_reasoning.into()); + other.insert( + "return_hidden_states".to_string(), + self.return_hidden_states.into(), + ); + ChatReqInput { stream: self.stream, bootstrap_host: None, @@ -271,18 +330,136 @@ mod tests { use serde_json::json; use std::collections::HashMap; - // ============= GenerateRequest to_pd_request Tests ============= + // ============= Test Helper Functions ============= + // + // These helper functions create default request instances with all required SGLang extension fields + // properly initialized. Use the struct spread operator `..default_*_request()` to override only + // the fields you need for specific tests, avoiding repetitive boilerplate code. + // + // Example usage: + // let req = GenerateRequest { + // text: Some("Custom text".to_string()), + // stream: true, + // ..default_generate_request() + // }; - #[test] - fn test_generate_to_pd_request_with_text_only() { - let req = GenerateRequest { - text: Some("Hello world".to_string()), + /// Create a default GenerateRequest with minimal fields set + fn default_generate_request() -> GenerateRequest { + GenerateRequest { + text: None, prompt: None, input_ids: None, stream: false, parameters: None, sampling_params: None, return_logprob: false, + // SGLang Extensions + lora_path: None, + session_params: None, + return_hidden_states: false, + rid: None, + } + } + + /// Create a default CompletionRequest with minimal fields set + fn default_completion_request() -> CompletionRequest { + CompletionRequest { + model: "test-model".to_string(), + prompt: StringOrArray::String("test prompt".to_string()), + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + suffix: None, + // SGLang Extensions + top_k: None, + min_p: None, + min_tokens: None, + repetition_penalty: None, + regex: None, + ebnf: None, + json_schema: None, + stop_token_ids: None, + no_stop_trim: false, + ignore_eos: false, + skip_special_tokens: true, + // SGLang Extensions + lora_path: None, + session_params: None, + return_hidden_states: false, + other: serde_json::Map::new(), + } + } + + /// Create a default ChatCompletionRequest with minimal fields set + fn default_chat_completion_request() -> ChatCompletionRequest { + ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("test message".to_string()), + name: None, + }], + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + logprobs: false, + top_logprobs: None, + user: None, + seed: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + functions: None, + function_call: None, + // SGLang Extensions + top_k: None, + min_p: None, + min_tokens: None, + repetition_penalty: None, + regex: None, + ebnf: None, + stop_token_ids: None, + no_stop_trim: false, + ignore_eos: false, + continue_final_message: false, + skip_special_tokens: true, + // SGLang Extensions + lora_path: None, + session_params: None, + separate_reasoning: true, + stream_reasoning: true, + return_hidden_states: false, + } + } + + // ============= GenerateRequest to_pd_request Tests ============= + + #[test] + fn test_generate_to_pd_request_with_text_only() { + let req = GenerateRequest { + text: Some("Hello world".to_string()), + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -308,13 +485,10 @@ mod tests { #[test] fn test_generate_to_pd_request_with_prompt_string() { let req = GenerateRequest { - text: None, prompt: Some(StringOrArray::String("Test prompt".to_string())), - input_ids: None, stream: true, - parameters: None, - sampling_params: None, return_logprob: true, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -342,6 +516,7 @@ mod tests { parameters: None, sampling_params: None, return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -360,13 +535,8 @@ mod tests { #[test] fn test_generate_to_pd_request_with_single_input_ids() { let req = GenerateRequest { - text: None, - prompt: None, input_ids: Some(InputIds::Single(vec![100, 200, 300, 400])), - stream: false, - parameters: None, - sampling_params: None, - return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -381,17 +551,12 @@ mod tests { #[test] fn test_generate_to_pd_request_with_batch_input_ids() { let req = GenerateRequest { - text: None, - prompt: None, input_ids: Some(InputIds::Batch(vec![ vec![1, 2, 3], vec![4, 5, 6, 7], vec![8, 9], ])), - stream: false, - parameters: None, - sampling_params: None, - return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -413,10 +578,7 @@ mod tests { text: Some("SGLang text".to_string()), prompt: Some(StringOrArray::String("OpenAI prompt".to_string())), input_ids: Some(InputIds::Single(vec![1, 2, 3])), - stream: false, - parameters: None, - sampling_params: None, - return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -429,13 +591,9 @@ mod tests { #[test] fn test_generate_to_pd_request_priority_prompt_over_input_ids() { let req = GenerateRequest { - text: None, prompt: Some(StringOrArray::String("OpenAI prompt".to_string())), input_ids: Some(InputIds::Single(vec![1, 2, 3])), - stream: false, - parameters: None, - sampling_params: None, - return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -459,12 +617,8 @@ mod tests { let req = GenerateRequest { text: Some("test".to_string()), - prompt: None, - input_ids: None, - stream: false, parameters: Some(params), - sampling_params: None, - return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -497,12 +651,8 @@ mod tests { let req = GenerateRequest { text: Some("test".to_string()), - prompt: None, - input_ids: None, - stream: false, - parameters: None, sampling_params: Some(sampling), - return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -546,6 +696,7 @@ mod tests { parameters: Some(params), sampling_params: Some(sampling), return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -568,6 +719,7 @@ mod tests { parameters: Some(params), sampling_params: None, return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -603,6 +755,7 @@ mod tests { parameters: Some(params), sampling_params: Some(sampling), return_logprob: true, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -632,23 +785,7 @@ mod tests { let req = CompletionRequest { model: "gpt-3.5-turbo".to_string(), prompt: StringOrArray::String("Complete this sentence".to_string()), - max_tokens: None, - temperature: None, - top_p: None, - n: None, - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, - presence_penalty: None, - frequency_penalty: None, - best_of: None, - logit_bias: None, - user: None, - seed: None, - suffix: None, - other: serde_json::Map::new(), + ..default_completion_request() }; let pd_req = req.to_pd_request(); @@ -672,23 +809,7 @@ mod tests { "First prompt".to_string(), "Second prompt".to_string(), ]), - max_tokens: None, - temperature: None, - top_p: None, - n: None, - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, - presence_penalty: None, - frequency_penalty: None, - best_of: None, - logit_bias: None, - user: None, - seed: None, - suffix: None, - other: serde_json::Map::new(), + ..default_completion_request() }; let pd_req = req.to_pd_request(); @@ -727,7 +848,7 @@ mod tests { user: Some("user123".to_string()), seed: Some(42), suffix: Some("...".to_string()), - other: serde_json::Map::new(), + ..default_completion_request() }; let pd_req = req.to_pd_request(); @@ -771,7 +892,7 @@ mod tests { user: None, seed: None, suffix: None, - other: serde_json::Map::new(), + ..default_completion_request() }; let pd_req = req.to_pd_request(); @@ -803,7 +924,7 @@ mod tests { user: None, seed: None, suffix: None, - other: serde_json::Map::new(), + ..default_completion_request() }; let pd_req = req.to_pd_request(); @@ -834,27 +955,7 @@ mod tests { let req = ChatCompletionRequest { messages, model: "gpt-4".to_string(), - temperature: None, - top_p: None, - n: None, - stream: false, - stream_options: None, - stop: None, - max_tokens: None, - max_completion_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - logprobs: false, - top_logprobs: None, - user: None, - seed: None, - response_format: None, - tools: None, - tool_choice: None, - parallel_tool_calls: None, - functions: None, - function_call: None, + ..default_chat_completion_request() }; let pd_req = req.to_pd_request(); @@ -883,7 +984,7 @@ mod tests { }]; let mut logit_bias = HashMap::new(); - logit_bias.insert("50256".to_string(), -100); + logit_bias.insert("50256".to_string(), -100.0f32); let tool = Tool { tool_type: "function".to_string(), @@ -920,6 +1021,7 @@ mod tests { parallel_tool_calls: Some(false), functions: None, function_call: None, + ..default_chat_completion_request() }; let pd_req = req.to_pd_request(); @@ -968,27 +1070,7 @@ mod tests { let req = ChatCompletionRequest { messages, model: "gpt-4-vision".to_string(), - temperature: None, - top_p: None, - n: None, - stream: false, - stream_options: None, - stop: None, - max_tokens: None, - max_completion_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - logprobs: false, - top_logprobs: None, - user: None, - seed: None, - response_format: None, - tools: None, - tool_choice: None, - parallel_tool_calls: None, - functions: None, - function_call: None, + ..default_chat_completion_request() }; let pd_req = req.to_pd_request(); @@ -1037,6 +1119,7 @@ mod tests { parallel_tool_calls: None, functions: None, function_call: None, + ..default_chat_completion_request() }; let pd_req = req.to_pd_request(); @@ -1054,32 +1137,13 @@ mod tests { name: None, tool_calls: None, function_call: None, + reasoning_content: None, }]; let req = ChatCompletionRequest { messages, model: "gpt-3.5-turbo".to_string(), - temperature: None, - top_p: None, - n: None, - stream: false, - stream_options: None, - stop: None, - max_tokens: None, - max_completion_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - logprobs: false, - top_logprobs: None, - user: None, - seed: None, - response_format: None, - tools: None, - tool_choice: None, - parallel_tool_calls: None, - functions: None, - function_call: None, + ..default_chat_completion_request() }; let pd_req = req.to_pd_request(); @@ -1101,12 +1165,7 @@ mod tests { fn test_routeable_request_to_json() { let req = GenerateRequest { text: Some("test".to_string()), - prompt: None, - input_ids: None, - stream: false, - parameters: None, - sampling_params: None, - return_logprob: false, + ..default_generate_request() }; let json = req.to_json().unwrap(); @@ -1166,6 +1225,7 @@ mod tests { parameters: Some(params), sampling_params: None, return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -1187,6 +1247,7 @@ mod tests { parameters: None, sampling_params: None, return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -1206,12 +1267,7 @@ mod tests { let req = GenerateRequest { text: Some(unicode_text.clone()), - prompt: None, - input_ids: None, - stream: false, - parameters: None, - sampling_params: None, - return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -1250,6 +1306,7 @@ mod tests { parameters: Some(params), sampling_params: None, return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -1265,12 +1322,7 @@ mod tests { fn test_bootstrap_fields_none() { let req = GenerateRequest { text: Some("test".to_string()), - prompt: None, - input_ids: None, - stream: false, - parameters: None, - sampling_params: None, - return_logprob: false, + ..default_generate_request() }; let pd_req = req.to_pd_request(); @@ -1279,4 +1331,182 @@ mod tests { assert_eq!(pd_req.bootstrap_port, None); assert_eq!(pd_req.bootstrap_room, None); } + + // ============= SGLang Extension Field Pass-Through Tests ============= + + #[test] + fn test_chat_completion_sglang_extensions_passed_through() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Test".to_string()), + name: None, + }]; + + let mut session_params = std::collections::HashMap::new(); + session_params.insert( + "key".to_string(), + serde_json::Value::String("value".to_string()), + ); + + let req = ChatCompletionRequest { + messages, + model: "test-model".to_string(), + // SGLang Extensions - Priority 1 + top_k: Some(40), + min_p: Some(0.05), + min_tokens: Some(10), + repetition_penalty: Some(1.1), + regex: Some("test_regex".to_string()), + ebnf: Some("test_ebnf".to_string()), + stop_token_ids: Some(vec![1, 2, 3]), + // SGLang Extensions - Priority 2 + lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))), + session_params: Some(session_params.clone()), + // Boolean extensions (ChatCompletionRequest has these as bool, not Option) + no_stop_trim: true, + ignore_eos: false, + continue_final_message: true, + skip_special_tokens: false, + separate_reasoning: true, + stream_reasoning: false, + return_hidden_states: true, + ..default_chat_completion_request() + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Verify SGLang extensions are passed through + assert_eq!(other.get("top_k"), Some(&json!(40))); + assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001); + assert_eq!(other.get("min_tokens"), Some(&json!(10))); + assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001); + assert_eq!(other.get("regex"), Some(&json!("test_regex"))); + assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf"))); + assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3]))); + assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin"))); + assert_eq!( + other.get("session_params"), + Some(&serde_json::to_value(&session_params).unwrap()) + ); + + // Verify boolean extensions + assert_eq!(other.get("no_stop_trim"), Some(&json!(true))); + assert_eq!(other.get("ignore_eos"), Some(&json!(false))); + assert_eq!(other.get("continue_final_message"), Some(&json!(true))); + assert_eq!(other.get("skip_special_tokens"), Some(&json!(false))); + assert_eq!(other.get("separate_reasoning"), Some(&json!(true))); + assert_eq!(other.get("stream_reasoning"), Some(&json!(false))); + assert_eq!(other.get("return_hidden_states"), Some(&json!(true))); + } + + #[test] + fn test_completion_request_sglang_extensions_passed_through() { + let mut session_params = std::collections::HashMap::new(); + session_params.insert( + "key".to_string(), + serde_json::Value::String("value".to_string()), + ); + + let req = CompletionRequest { + prompt: StringOrArray::String("Test prompt".to_string()), + model: "test-model".to_string(), + // SGLang Extensions - Priority 1 + top_k: Some(40), + min_p: Some(0.05), + min_tokens: Some(10), + repetition_penalty: Some(1.1), + regex: Some("test_regex".to_string()), + ebnf: Some("test_ebnf".to_string()), + stop_token_ids: Some(vec![1, 2, 3]), + // SGLang Extensions - Priority 2 + lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))), + session_params: Some(session_params.clone()), + // Boolean extensions (CompletionRequest only has these 4 boolean fields) + no_stop_trim: true, + ignore_eos: false, + skip_special_tokens: false, + return_hidden_states: true, + ..default_completion_request() + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Verify SGLang extensions are passed through + assert_eq!(other.get("top_k"), Some(&json!(40))); + assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001); + assert_eq!(other.get("min_tokens"), Some(&json!(10))); + assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001); + assert_eq!(other.get("regex"), Some(&json!("test_regex"))); + assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf"))); + assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3]))); + assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin"))); + assert_eq!( + other.get("session_params"), + Some(&serde_json::to_value(&session_params).unwrap()) + ); + + // Verify boolean extensions (only the ones CompletionRequest has) + assert_eq!(other.get("no_stop_trim"), Some(&json!(true))); + assert_eq!(other.get("ignore_eos"), Some(&json!(false))); + assert_eq!(other.get("skip_special_tokens"), Some(&json!(false))); + assert_eq!(other.get("return_hidden_states"), Some(&json!(true))); + } + + #[test] + fn test_sglang_extensions_none_values_not_passed_through() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Test".to_string()), + name: None, + }]; + + let req = ChatCompletionRequest { + messages, + model: "test-model".to_string(), + // All SGLang extensions as None/default - Optional fields won't appear, bools will use defaults + top_k: None, + min_p: None, + min_tokens: None, + repetition_penalty: None, + regex: None, + ebnf: None, + stop_token_ids: None, + lora_path: None, + session_params: None, + // Boolean fields use defaults (false for most, true for some with default_true) + no_stop_trim: false, + ignore_eos: false, + continue_final_message: false, + skip_special_tokens: true, // This has default_true + separate_reasoning: true, // This has default_true + stream_reasoning: true, // This has default_true + return_hidden_states: false, + ..default_chat_completion_request() + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Verify None values are not included + assert!(!other.contains_key("top_k")); + assert!(!other.contains_key("min_p")); + assert!(!other.contains_key("min_tokens")); + assert!(!other.contains_key("repetition_penalty")); + assert!(!other.contains_key("regex")); + assert!(!other.contains_key("ebnf")); + assert!(!other.contains_key("stop_token_ids")); + assert!(!other.contains_key("lora_path")); + assert!(!other.contains_key("session_params")); + + // Boolean fields are always present with their values (can't be None) + assert_eq!(other.get("no_stop_trim"), Some(&json!(false))); + assert_eq!(other.get("ignore_eos"), Some(&json!(false))); + assert_eq!(other.get("continue_final_message"), Some(&json!(false))); + assert_eq!(other.get("skip_special_tokens"), Some(&json!(true))); // default_true + assert_eq!(other.get("separate_reasoning"), Some(&json!(true))); // default_true + assert_eq!(other.get("stream_reasoning"), Some(&json!(true))); // default_true + assert_eq!(other.get("return_hidden_states"), Some(&json!(false))); + } } diff --git a/sgl-router/tests/benchmark_integration.rs b/sgl-router/tests/benchmark_integration.rs index b7876e223..75c55986f 100644 --- a/sgl-router/tests/benchmark_integration.rs +++ b/sgl-router/tests/benchmark_integration.rs @@ -8,14 +8,118 @@ use sglang_router_rs::openai_api_types::{ }; use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest}; +/// Create a default GenerateRequest for benchmarks with minimal fields set +fn default_generate_request() -> GenerateRequest { + GenerateRequest { + text: None, + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + // SGLang Extensions + lora_path: None, + session_params: None, + return_hidden_states: false, + rid: None, + } +} + +/// Create a default ChatCompletionRequest for benchmarks with minimal fields set +fn default_chat_completion_request() -> ChatCompletionRequest { + ChatCompletionRequest { + model: String::new(), + messages: vec![], + max_tokens: None, + max_completion_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + logprobs: false, + top_logprobs: None, + user: None, + response_format: None, + seed: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + function_call: None, + functions: None, + // SGLang Extensions + top_k: None, + min_p: None, + min_tokens: None, + repetition_penalty: None, + regex: None, + ebnf: None, + stop_token_ids: None, + no_stop_trim: false, + ignore_eos: false, + continue_final_message: false, + skip_special_tokens: true, + // SGLang Extensions + lora_path: None, + session_params: None, + separate_reasoning: true, + stream_reasoning: true, + return_hidden_states: false, + } +} + +/// Create a default CompletionRequest for benchmarks with minimal fields set +fn default_completion_request() -> CompletionRequest { + CompletionRequest { + model: String::new(), + prompt: StringOrArray::String(String::new()), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + // SGLang Extensions + top_k: None, + min_p: None, + min_tokens: None, + repetition_penalty: None, + regex: None, + ebnf: None, + json_schema: None, + stop_token_ids: None, + no_stop_trim: false, + ignore_eos: false, + skip_special_tokens: true, + // SGLang Extensions + lora_path: None, + session_params: None, + return_hidden_states: false, + other: serde_json::Map::new(), + } +} + #[test] fn test_benchmark_request_creation() { // Ensure all benchmark request types can be created without panicking let generate_req = GenerateRequest { text: Some("Test prompt".to_string()), - input_ids: None, - prompt: None, parameters: Some(GenerateParameters { max_new_tokens: Some(100), temperature: Some(0.8), @@ -33,8 +137,7 @@ fn test_benchmark_request_creation() { repetition_penalty: Some(1.0), ..Default::default() }), - stream: false, - return_logprob: false, + ..default_generate_request() }; let chat_req = ChatCompletionRequest { @@ -49,44 +152,23 @@ fn test_benchmark_request_creation() { temperature: Some(0.7), top_p: Some(1.0), n: Some(1), - stream: false, - stream_options: None, - stop: None, presence_penalty: Some(0.0), frequency_penalty: Some(0.0), - logit_bias: None, - logprobs: false, - top_logprobs: None, - user: None, - response_format: None, - seed: None, - tools: None, - tool_choice: None, parallel_tool_calls: Some(true), - function_call: None, - functions: None, + ..default_chat_completion_request() }; let completion_req = CompletionRequest { model: "test-model".to_string(), prompt: StringOrArray::String("Test prompt".to_string()), - suffix: None, max_tokens: Some(50), temperature: Some(0.8), top_p: Some(1.0), n: Some(1), - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, presence_penalty: Some(0.0), frequency_penalty: Some(0.0), best_of: Some(1), - logit_bias: None, - user: None, - seed: None, - other: serde_json::Map::new(), + ..default_completion_request() }; // Test serialization works @@ -101,12 +183,7 @@ fn test_benchmark_serialization_roundtrip() { let generate_req = GenerateRequest { text: Some("Test prompt".to_string()), - input_ids: None, - prompt: None, - parameters: None, - sampling_params: None, - stream: false, - return_logprob: false, + ..default_generate_request() }; // Serialize and deserialize @@ -125,12 +202,7 @@ fn test_benchmark_request_adaptation() { let generate_req = GenerateRequest { text: Some("Test prompt".to_string()), - input_ids: None, - prompt: None, - parameters: None, - sampling_params: None, - stream: false, - return_logprob: false, + ..default_generate_request() }; let chat_req = ChatCompletionRequest { @@ -145,44 +217,23 @@ fn test_benchmark_request_adaptation() { temperature: Some(0.7), top_p: Some(1.0), n: Some(1), - stream: false, - stream_options: None, - stop: None, presence_penalty: Some(0.0), frequency_penalty: Some(0.0), - logit_bias: None, - logprobs: false, - top_logprobs: None, - user: None, - response_format: None, - seed: None, - tools: None, - tool_choice: None, parallel_tool_calls: Some(true), - function_call: None, - functions: None, + ..default_chat_completion_request() }; let completion_req = CompletionRequest { model: "test-model".to_string(), prompt: StringOrArray::String("Test prompt".to_string()), - suffix: None, max_tokens: Some(50), temperature: Some(0.8), top_p: Some(1.0), n: Some(1), - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, presence_penalty: Some(0.0), frequency_penalty: Some(0.0), best_of: Some(1), - logit_bias: None, - user: None, - seed: None, - other: serde_json::Map::new(), + ..default_completion_request() }; // Test PD adaptation (should not panic) @@ -197,12 +248,7 @@ fn test_benchmark_regular_routing() { let generate_req = GenerateRequest { text: Some("Test prompt".to_string()), - input_ids: None, - prompt: None, - parameters: None, - sampling_params: None, - stream: false, - return_logprob: false, + ..default_generate_request() }; // Test regular routing methods (should not panic) @@ -217,12 +263,7 @@ fn test_benchmark_performance_baseline() { let generate_req = GenerateRequest { text: Some("Short test prompt".to_string()), - input_ids: None, - prompt: None, - parameters: None, - sampling_params: None, - stream: false, - return_logprob: false, + ..default_generate_request() }; // Serialization should be fast (< 1ms for simple requests)