From d2478cd4ff857eb93d73b754adfd0eef5d67d156 Mon Sep 17 00:00:00 2001 From: Keyang Ru Date: Wed, 15 Oct 2025 09:59:38 -0700 Subject: [PATCH] [router] Fix response api related spec (#11621) --- sgl-router/src/middleware.rs | 2 + sgl-router/src/protocols/spec.rs | 129 +++---------------------- sgl-router/tests/api_endpoints_test.rs | 2 +- sgl-router/tests/responses_api_test.rs | 49 ++++++---- sgl-router/tests/spec/rerank.rs | 8 +- 5 files changed, 50 insertions(+), 140 deletions(-) diff --git a/sgl-router/src/middleware.rs b/sgl-router/src/middleware.rs index 6e6344900..924edcee6 100644 --- a/sgl-router/src/middleware.rs +++ b/sgl-router/src/middleware.rs @@ -69,6 +69,8 @@ fn generate_request_id(path: &str) -> String { "cmpl-" } else if path.contains("/generate") { "gnt-" + } else if path.contains("/responses") { + "resp-" } else { "req-" }; diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index 982a3026e..394b0d28d 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use serde_json::{to_value, Map, Number, Value}; +use serde_json::{to_value, Map, Value}; use std::collections::HashMap; use validator::Validate; @@ -1325,10 +1325,6 @@ impl ResponsesUsage { } } -fn generate_request_id() -> String { - format!("resp_{}", uuid::Uuid::new_v4().simple()) -} - #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ResponsesRequest { /// Run the request in the background @@ -1419,8 +1415,8 @@ pub struct ResponsesRequest { pub user: Option, /// Request ID - #[serde(default = "generate_request_id")] - pub request_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub request_id: Option, /// Request priority #[serde(default)] @@ -1438,15 +1434,15 @@ pub struct ResponsesRequest { #[serde(skip_serializing_if = "Option::is_none")] pub stop: Option, - /// Top-k sampling parameter + /// Top-k sampling parameter (SGLang extension) #[serde(default = "default_top_k")] pub top_k: i32, - /// Min-p sampling parameter + /// Min-p sampling parameter (SGLang extension) #[serde(default)] pub min_p: f32, - /// Repetition penalty + /// Repetition penalty (SGLang extension) #[serde(default = "default_repetition_penalty")] pub repetition_penalty: f32, } @@ -1491,7 +1487,7 @@ impl Default for ResponsesRequest { top_p: None, truncation: None, user: None, - request_id: generate_request_id(), + request_id: None, priority: 0, frequency_penalty: None, presence_penalty: None, @@ -1503,101 +1499,6 @@ impl Default for ResponsesRequest { } } -impl ResponsesRequest { - /// Default sampling parameters - const DEFAULT_TEMPERATURE: f32 = 0.7; - const DEFAULT_TOP_P: f32 = 1.0; - - /// Convert to sampling parameters for generation - pub fn to_sampling_params( - &self, - default_max_tokens: u32, - default_params: Option>, - ) -> HashMap { - let mut params = HashMap::new(); - - // Use max_output_tokens if available - let max_tokens = if let Some(max_output) = self.max_output_tokens { - std::cmp::min(max_output, default_max_tokens) - } else { - default_max_tokens - }; - - // Avoid exceeding context length by minus 1 token - let max_tokens = max_tokens.saturating_sub(1); - - // Temperature - let temperature = self.temperature.unwrap_or_else(|| { - default_params - .as_ref() - .and_then(|p| p.get("temperature")) - .and_then(|v| v.as_f64()) - .map(|v| v as f32) - .unwrap_or(Self::DEFAULT_TEMPERATURE) - }); - - // Top-p - let top_p = self.top_p.unwrap_or_else(|| { - default_params - .as_ref() - .and_then(|p| p.get("top_p")) - .and_then(|v| v.as_f64()) - .map(|v| v as f32) - .unwrap_or(Self::DEFAULT_TOP_P) - }); - - params.insert( - "max_new_tokens".to_string(), - Value::Number(Number::from(max_tokens)), - ); - params.insert( - "temperature".to_string(), - Value::Number(Number::from_f64(temperature as f64).unwrap()), - ); - params.insert( - "top_p".to_string(), - Value::Number(Number::from_f64(top_p as f64).unwrap()), - ); - if let Some(fp) = self.frequency_penalty { - params.insert( - "frequency_penalty".to_string(), - Value::Number(Number::from_f64(fp as f64).unwrap()), - ); - } - if let Some(pp) = self.presence_penalty { - params.insert( - "presence_penalty".to_string(), - Value::Number(Number::from_f64(pp as f64).unwrap()), - ); - } - params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k))); - params.insert( - "min_p".to_string(), - Value::Number(Number::from_f64(self.min_p as f64).unwrap()), - ); - params.insert( - "repetition_penalty".to_string(), - Value::Number(Number::from_f64(self.repetition_penalty as f64).unwrap()), - ); - - if let Some(ref stop) = self.stop { - match to_value(stop) { - Ok(value) => params.insert("stop".to_string(), value), - Err(_) => params.insert("stop".to_string(), Value::Null), - }; - } - - // Apply any additional default parameters - if let Some(default_params) = default_params { - for (key, value) in default_params { - params.entry(key).or_insert(value); - } - } - - params - } -} - impl GenerationRequest for ResponsesRequest { fn is_stream(&self) -> bool { self.stream.unwrap_or(false) @@ -1776,7 +1677,10 @@ impl ResponsesResponse { usage: Option, ) -> Self { Self { - id: request.request_id.clone(), + id: request + .request_id + .clone() + .expect("request_id should be set by middleware"), object: "response".to_string(), created_at: created_time, status, @@ -2535,9 +2439,6 @@ pub enum GenerateFinishReason { Other(Value), } -// Constants for rerank API -pub const DEFAULT_MODEL_NAME: &str = "default"; - /// Rerank request for scoring documents against a query /// Used for RAG systems and document relevance scoring #[derive(Debug, Clone, Serialize, Deserialize)] @@ -2549,7 +2450,7 @@ pub struct RerankRequest { pub documents: Vec, /// Model to use for reranking - #[serde(default = "default_model_name")] + #[serde(default = "default_model")] pub model: String, /// Maximum number of documents to return (optional) @@ -2567,10 +2468,6 @@ pub struct RerankRequest { pub user: Option, } -pub fn default_model_name() -> String { - DEFAULT_MODEL_NAME.to_string() -} - fn default_return_documents() -> bool { true } @@ -2634,7 +2531,7 @@ impl From for RerankRequest { RerankRequest { query: v1.query, documents: v1.documents, - model: default_model_name(), + model: default_model(), top_k: None, return_documents: true, rid: None, diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index 44392a5e2..ffdb4997c 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -2156,7 +2156,7 @@ mod rerank_tests { assert!(body_json.get("model").is_some()); // V1 API should use default model name - assert_eq!(body_json["model"], "default"); + assert_eq!(body_json["model"], "unknown"); let results = body_json["results"].as_array().unwrap(); assert_eq!(results.len(), 3); // All documents should be returned diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs index cd155baf8..60ab83c9f 100644 --- a/sgl-router/tests/responses_api_test.rs +++ b/sgl-router/tests/responses_api_test.rs @@ -115,7 +115,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { top_p: None, truncation: Some(Truncation::Disabled), user: None, - request_id: "resp_test_mcp_e2e".to_string(), + request_id: Some("resp_test_mcp_e2e".to_string()), priority: 0, frequency_penalty: Some(0.0), presence_penalty: Some(0.0), @@ -361,7 +361,7 @@ fn test_responses_request_creation() { top_p: Some(0.9), truncation: Some(Truncation::Disabled), user: Some("test-user".to_string()), - request_id: "resp_test123".to_string(), + request_id: Some("resp_test123".to_string()), priority: 0, frequency_penalty: Some(0.0), presence_penalty: Some(0.0), @@ -379,7 +379,8 @@ fn test_responses_request_creation() { } #[test] -fn test_sampling_params_conversion() { +fn test_responses_request_sglang_extensions() { + // Test that SGLang-specific sampling parameters are present and serializable let request = ResponsesRequest { background: Some(false), include: None, @@ -389,37 +390,44 @@ fn test_sampling_params_conversion() { max_tool_calls: None, metadata: None, model: Some("test-model".to_string()), - parallel_tool_calls: Some(true), // Use default true + parallel_tool_calls: Some(true), previous_response_id: None, reasoning: None, service_tier: Some(ServiceTier::Auto), - store: Some(true), // Use default true + store: Some(true), stream: Some(false), temperature: Some(0.8), tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)), tools: Some(vec![]), - top_logprobs: Some(0), // Use default 0 + top_logprobs: Some(0), top_p: Some(0.95), truncation: Some(Truncation::Auto), user: None, - request_id: "resp_test456".to_string(), + request_id: Some("resp_test456".to_string()), priority: 0, frequency_penalty: Some(0.1), presence_penalty: Some(0.2), stop: None, + // SGLang-specific extensions: top_k: 10, min_p: 0.05, repetition_penalty: 1.1, conversation: None, }; - let params = request.to_sampling_params(1000, None); + // Verify SGLang extensions are present + assert_eq!(request.top_k, 10); + assert_eq!(request.min_p, 0.05); + assert_eq!(request.repetition_penalty, 1.1); - // Check that parameters are converted correctly - assert!(params.contains_key("temperature")); - assert!(params.contains_key("top_p")); - assert!(params.contains_key("frequency_penalty")); - assert!(params.contains_key("max_new_tokens")); + // Verify serialization works with SGLang extensions + let json = serde_json::to_string(&request).expect("Serialization should work"); + let parsed: ResponsesRequest = + serde_json::from_str(&json).expect("Deserialization should work"); + + assert_eq!(parsed.top_k, 10); + assert_eq!(parsed.min_p, 0.05); + assert_eq!(parsed.repetition_penalty, 1.1); } #[test] @@ -516,7 +524,7 @@ fn test_json_serialization() { top_p: Some(0.8), truncation: Some(Truncation::Auto), user: Some("test_user".to_string()), - request_id: "resp_comprehensive_test".to_string(), + request_id: Some("resp_comprehensive_test".to_string()), priority: 1, frequency_penalty: Some(0.3), presence_penalty: Some(0.4), @@ -531,7 +539,10 @@ fn test_json_serialization() { let parsed: ResponsesRequest = serde_json::from_str(&json).expect("Deserialization should work"); - assert_eq!(parsed.request_id, "resp_comprehensive_test"); + assert_eq!( + parsed.request_id, + Some("resp_comprehensive_test".to_string()) + ); assert_eq!(parsed.model, Some("gpt-4".to_string())); assert_eq!(parsed.background, Some(true)); assert_eq!(parsed.stream, Some(true)); @@ -643,7 +654,7 @@ async fn test_multi_turn_loop_with_mcp() { top_p: Some(1.0), truncation: Some(Truncation::Disabled), user: None, - request_id: "resp_multi_turn_test".to_string(), + request_id: Some("resp_multi_turn_test".to_string()), priority: 0, frequency_penalty: Some(0.0), presence_penalty: Some(0.0), @@ -816,7 +827,7 @@ async fn test_max_tool_calls_limit() { top_p: Some(1.0), truncation: Some(Truncation::Disabled), user: None, - request_id: "resp_max_calls_test".to_string(), + request_id: Some("resp_max_calls_test".to_string()), priority: 0, frequency_penalty: Some(0.0), presence_penalty: Some(0.0), @@ -1011,7 +1022,7 @@ async fn test_streaming_with_mcp_tool_calls() { top_p: Some(1.0), truncation: Some(Truncation::Disabled), user: None, - request_id: "resp_streaming_mcp_test".to_string(), + request_id: Some("resp_streaming_mcp_test".to_string()), priority: 0, frequency_penalty: Some(0.0), presence_penalty: Some(0.0), @@ -1290,7 +1301,7 @@ async fn test_streaming_multi_turn_with_mcp() { top_p: Some(1.0), truncation: Some(Truncation::Disabled), user: None, - request_id: "resp_streaming_multiturn_test".to_string(), + request_id: Some("resp_streaming_multiturn_test".to_string()), priority: 0, frequency_penalty: Some(0.0), presence_penalty: Some(0.0), diff --git a/sgl-router/tests/spec/rerank.rs b/sgl-router/tests/spec/rerank.rs index 88b296298..3a0ca9aa8 100644 --- a/sgl-router/tests/spec/rerank.rs +++ b/sgl-router/tests/spec/rerank.rs @@ -1,7 +1,7 @@ use serde_json::{from_str, to_string, Number, Value}; use sglang_router_rs::protocols::spec::{ - default_model_name, GenerationRequest, RerankRequest, RerankResponse, RerankResult, - StringOrArray, UsageInfo, V1RerankReqInput, + GenerationRequest, RerankRequest, RerankResponse, RerankResult, StringOrArray, UsageInfo, + V1RerankReqInput, }; use std::collections::HashMap; @@ -40,7 +40,7 @@ fn test_rerank_request_deserialization_with_defaults() { assert_eq!(request.query, "test query"); assert_eq!(request.documents, vec!["doc1", "doc2"]); - assert_eq!(request.model, default_model_name()); + assert_eq!(request.model, "unknown"); assert_eq!(request.top_k, None); assert!(request.return_documents); assert_eq!(request.rid, None); @@ -414,7 +414,7 @@ fn test_v1_to_rerank_request_conversion() { assert_eq!(request.query, "test query"); assert_eq!(request.documents, vec!["doc1", "doc2"]); - assert_eq!(request.model, default_model_name()); + assert_eq!(request.model, "unknown"); assert_eq!(request.top_k, None); assert!(request.return_documents); assert_eq!(request.rid, None);