diff --git a/sgl-router/src/data_connector/response_memory_store.rs b/sgl-router/src/data_connector/response_memory_store.rs index 003d07ac7..764d7568c 100644 --- a/sgl-router/src/data_connector/response_memory_store.rs +++ b/sgl-router/src/data_connector/response_memory_store.rs @@ -74,13 +74,16 @@ impl ResponseStorage for MemoryResponseStorage { // Store the response store.responses.insert(response_id.clone(), response); + tracing::info!("memory_store_size" = store.responses.len()); Ok(response_id) } async fn get_response(&self, response_id: &ResponseId) -> Result<Option<StoredResponse>> { let store = self.store.read(); - Ok(store.responses.get(response_id).cloned()) + let result = store.responses.get(response_id).cloned(); + tracing::info!("memory_get_response" = %response_id.0, found = result.is_some()); + Ok(result) } async fn delete_response(&self, response_id: &ResponseId) -> Result<()> { @@ -200,6 +203,20 @@ pub struct MemoryStoreStats { mod tests { use super::*; + #[tokio::test] + async fn test_store_with_custom_id() { + let store = MemoryResponseStorage::new(); + let mut response = StoredResponse::new("Input".to_string(), "Output".to_string(), None); + response.id = ResponseId::from_string("resp_custom".to_string()); + store.store_response(response.clone()).await.unwrap(); + let retrieved = store + .get_response(&ResponseId::from_string("resp_custom".to_string())) + .await + .unwrap(); + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().output, "Output"); + } + #[tokio::test] async fn test_memory_store_basic() { let store = MemoryResponseStorage::new(); diff --git a/sgl-router/src/data_connector/responses.rs b/sgl-router/src/data_connector/responses.rs index 49693e984..175311ef8 100644 --- a/sgl-router/src/data_connector/responses.rs +++ b/sgl-router/src/data_connector/responses.rs @@ -1,5 +1,6 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use serde_json::Value; use std::collections::HashMap; use std::sync::Arc; @@ -55,6 +56,10 @@ pub struct StoredResponse { /// Model used for generation pub model: Option<String>, + + /// Raw OpenAI response payload + #[serde(default)] + pub raw_response: Value, } impl StoredResponse { @@ -70,6 +75,7 @@ impl StoredResponse { created_at: chrono::Utc::now(), user: None, model: None, + raw_response: Value::Null, } } } @@ -175,3 +181,9 @@ pub trait ResponseStorage: Send + Sync { /// Type alias for shared storage pub type SharedResponseStorage = Arc<dyn ResponseStorage>; + +impl Default for StoredResponse { + fn default() -> Self { + Self::new(String::new(), String::new(), None) + } +} diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index 8e1d483ae..7e5bca611 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{to_value, Map, Number, Value}; use std::collections::HashMap; // # Protocol Specifications @@ -350,7 +350,7 @@ pub struct ChatCompletionRequest { /// Session parameters for continual prompting #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option<HashMap<String, serde_json::Value>>, + pub session_params: Option<HashMap<String, Value>>, /// Separate reasoning content from final answer (O1-style models) #[serde(default = "default_true")] @@ -362,7 +362,7 @@ /// Chat template kwargs #[serde(skip_serializing_if = "Option::is_none")] - pub chat_template_kwargs: Option<HashMap<String, serde_json::Value>>, + pub chat_template_kwargs: Option<HashMap<String, Value>>, /// Return model hidden states #[serde(default)] @@ -447,7 +447,7 @@ pub struct ChatChoice { pub 
finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call" /// Information about which stop condition was matched #[serde(skip_serializing_if = "Option::is_none")] - pub matched_stop: Option<serde_json::Value>, // Can be string or integer + pub matched_stop: Option<Value>, // Can be string or integer /// Hidden states from the model (SGLang extension) #[serde(skip_serializing_if = "Option::is_none")] pub hidden_states: Option<Vec<f32>>, @@ -606,7 +606,7 @@ pub struct CompletionRequest { /// Session parameters for continual prompting #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option<HashMap<String, serde_json::Value>>, + pub session_params: Option<HashMap<String, Value>>, /// Return model hidden states #[serde(default)] @@ -618,7 +618,7 @@ /// Additional fields including bootstrap info for PD routing #[serde(flatten)] - pub other: serde_json::Map<String, Value>, + pub other: Map<String, Value>, } impl GenerationRequest for CompletionRequest { @@ -662,7 +662,7 @@ pub struct CompletionChoice { pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc. /// Information about which stop condition was matched #[serde(skip_serializing_if = "Option::is_none")] - pub matched_stop: Option<serde_json::Value>, // Can be string or integer + pub matched_stop: Option<Value>, // Can be string or integer /// Hidden states from the model (SGLang extension) #[serde(skip_serializing_if = "Option::is_none")] pub hidden_states: Option<Vec<f32>>, @@ -776,6 +776,10 @@ pub enum ResponseContentPart { #[serde(skip_serializing_if = "Option::is_none")] logprobs: Option<ChatLogProbs>, }, + #[serde(rename = "input_text")] + InputText { text: String }, + #[serde(other)] + Unknown, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -864,6 +868,29 @@ pub enum ResponseStatus { Cancelled, } +// ============= Reasoning Info ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ReasoningInfo { + #[serde(skip_serializing_if = "Option::is_none")] + pub effort: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option<String>, +} + +// ============= Text Format ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponseTextFormat { + pub format: TextFormatType, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct TextFormatType { + #[serde(rename = "type")] + pub format_type: String, +} + // ============= Include Fields ============= #[derive(Debug, Clone, Deserialize, Serialize)] @@ -915,6 +942,13 @@ pub struct ResponseUsage { pub output_tokens_details: Option<OutputTokensDetails>, } +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ResponsesUsage { + Classic(UsageInfo), + Modern(ResponseUsage), +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct InputTokensDetails { pub cached_tokens: u32, } @@ -970,6 +1004,34 @@ impl ResponseUsage { } } +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct ResponsesGetParams { + #[serde(default)] + pub include: Vec<String>, + #[serde(default)] + pub include_obfuscation: Option<bool>, + #[serde(default)] + pub starting_after: Option<u64>, + #[serde(default)] + pub stream: Option<bool>, +} + +impl ResponsesUsage { + pub fn to_response_usage(&self) -> ResponseUsage { + match self { + ResponsesUsage::Classic(usage) => usage.to_response_usage(), + ResponsesUsage::Modern(usage) => usage.clone(), + } + } + + pub fn to_usage_info(&self) -> UsageInfo { + match self { + ResponsesUsage::Classic(usage) => usage.clone(), + ResponsesUsage::Modern(usage) => usage.to_usage_info(), + } + } +} + fn generate_request_id() -> String { format!("resp_{}", uuid::Uuid::new_v4().simple()) } @@ 
-1002,7 +1064,7 @@ pub struct ResponsesRequest { /// Additional metadata #[serde(skip_serializing_if = "Option::is_none")] - pub metadata: Option<HashMap<String, serde_json::Value>>, + pub metadata: Option<HashMap<String, Value>>, /// Model to use (optional to match vLLM) #[serde(skip_serializing_if = "Option::is_none")] @@ -1109,6 +1171,42 @@ fn default_repetition_penalty() -> f32 { 1.0 } +impl Default for ResponsesRequest { + fn default() -> Self { + Self { + background: false, + include: None, + input: ResponseInput::Text(String::new()), + instructions: None, + max_output_tokens: None, + max_tool_calls: None, + metadata: None, + model: None, + parallel_tool_calls: true, + previous_response_id: None, + reasoning: None, + service_tier: ServiceTier::default(), + store: true, + stream: false, + temperature: None, + tool_choice: ToolChoice::default(), + tools: Vec::new(), + top_logprobs: 0, + top_p: None, + truncation: Truncation::default(), + user: None, + request_id: generate_request_id(), + priority: 0, + frequency_penalty: 0.0, + presence_penalty: 0.0, + stop: None, + top_k: default_top_k(), + min_p: 0.0, + repetition_penalty: default_repetition_penalty(), + } + } +} + impl ResponsesRequest { /// Default sampling parameters const DEFAULT_TEMPERATURE: f32 = 0.7; @@ -1118,8 +1216,8 @@ impl ResponsesRequest { pub fn to_sampling_params( &self, default_max_tokens: u32, - default_params: Option<HashMap<String, serde_json::Value>>, - ) -> HashMap<String, serde_json::Value> { + default_params: Option<HashMap<String, Value>>, + ) -> HashMap<String, Value> { let mut params = HashMap::new(); // Use max_output_tokens if available @@ -1154,47 +1252,38 @@ impl ResponsesRequest { params.insert( "max_new_tokens".to_string(), - serde_json::Value::Number(serde_json::Number::from(max_tokens)), + Value::Number(Number::from(max_tokens)), ); params.insert( "temperature".to_string(), - serde_json::Value::Number(serde_json::Number::from_f64(temperature as f64).unwrap()), + Value::Number(Number::from_f64(temperature as f64).unwrap()), ); params.insert( "top_p".to_string(), - serde_json::Value::Number(serde_json::Number::from_f64(top_p as f64).unwrap()), + Value::Number(Number::from_f64(top_p as f64).unwrap()), ); params.insert( "frequency_penalty".to_string(), - serde_json::Value::Number( - serde_json::Number::from_f64(self.frequency_penalty as f64).unwrap(), - ), + Value::Number(Number::from_f64(self.frequency_penalty as f64).unwrap()), ); params.insert( "presence_penalty".to_string(), - serde_json::Value::Number( - serde_json::Number::from_f64(self.presence_penalty as f64).unwrap(), - ), - ); - params.insert( - "top_k".to_string(), - serde_json::Value::Number(serde_json::Number::from(self.top_k)), + Value::Number(Number::from_f64(self.presence_penalty as f64).unwrap()), ); + params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k))); params.insert( "min_p".to_string(), - serde_json::Value::Number(serde_json::Number::from_f64(self.min_p as f64).unwrap()), + Value::Number(Number::from_f64(self.min_p as f64).unwrap()), ); params.insert( "repetition_penalty".to_string(), - serde_json::Value::Number( - serde_json::Number::from_f64(self.repetition_penalty as f64).unwrap(), - ), + Value::Number(Number::from_f64(self.repetition_penalty as f64).unwrap()), ); if let Some(ref stop) = self.stop { - match serde_json::to_value(stop) { + match to_value(stop) { Ok(value) => params.insert("stop".to_string(), value), - Err(_) => params.insert("stop".to_string(), serde_json::Value::Null), + Err(_) => params.insert("stop".to_string(), Value::Null), }; } @@ -1227,8 +1316,10 @@ impl GenerationRequest for ResponsesRequest { ResponseInputOutputItem::Message { content, .. 
} => { let texts: Vec<String> = content .iter() - .map(|part| match part { - ResponseContentPart::OutputText { text, .. } => text.clone(), + .filter_map(|part| match part { + ResponseContentPart::OutputText { text, .. } => Some(text.clone()), + ResponseContentPart::InputText { text } => Some(text.clone()), + ResponseContentPart::Unknown => None, }) .collect(); if texts.is_empty() { @@ -1285,6 +1376,25 @@ pub struct ResponsesResponse { #[serde(default = "current_timestamp")] pub created_at: i64, + /// Response status + pub status: ResponseStatus, + + /// Error information if status is failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<Value>, + + /// Incomplete details if response was truncated + #[serde(skip_serializing_if = "Option::is_none")] + pub incomplete_details: Option<Value>, + + /// System instructions used + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option<String>, + + /// Max output tokens setting + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option<u32>, + /// Model name pub model: String, @@ -1292,17 +1402,30 @@ pub struct ResponsesResponse { #[serde(default)] pub output: Vec<ResponseOutputItem>, - /// Response status - pub status: ResponseStatus, - - /// Usage information - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option<UsageInfo>, - /// Whether parallel tool calls are enabled #[serde(default = "default_true")] pub parallel_tool_calls: bool, + /// Previous response ID if this is a continuation + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_response_id: Option<String>, + + /// Reasoning information + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option<ReasoningInfo>, + + /// Whether the response is stored + #[serde(default = "default_true")] + pub store: bool, + + /// Temperature setting used + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option<f32>, + + /// Text format settings + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option<ResponseTextFormat>, + /// Tool choice setting #[serde(default = "default_tool_choice")] pub tool_choice: String, @@ -1310,6 +1433,26 @@ /// Available tools #[serde(default)] pub tools: Vec<ResponseTool>, + + /// Top-p setting used + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option<f32>, + + /// Truncation strategy used + #[serde(skip_serializing_if = "Option::is_none")] + pub truncation: Option<String>, + + /// Usage information + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option<ResponsesUsage>, + + /// User identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option<String>, + + /// Additional metadata + #[serde(default)] + pub metadata: HashMap<String, Value>, } fn default_object_type() -> String { @@ -1325,7 +1468,7 @@ impl ResponsesResponse { #[allow(clippy::too_many_arguments)] pub fn from_request( request: &ResponsesRequest, - _sampling_params: &HashMap<String, serde_json::Value>, + _sampling_params: &HashMap<String, Value>, model_name: String, created_time: i64, output: Vec<ResponseOutputItem>, @@ -1336,11 +1479,26 @@ id: request.request_id.clone(), object: "response".to_string(), created_at: created_time, + status, + error: None, + incomplete_details: None, + instructions: request.instructions.clone(), + max_output_tokens: request.max_output_tokens, model: model_name, output, - status, - usage, parallel_tool_calls: request.parallel_tool_calls, + previous_response_id: request.previous_response_id.clone(), + reasoning: request.reasoning.as_ref().map(|r| ReasoningInfo { + effort: r.effort.as_ref().map(|e| format!("{:?}", e)), + summary: None, + }), + store: request.store, 
+ temperature: request.temperature, + text: Some(ResponseTextFormat { + format: TextFormatType { + format_type: "text".to_string(), + }, + }), tool_choice: match &request.tool_choice { ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(), ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(), @@ -1348,6 +1506,14 @@ ToolChoice::Function { .. } => "function".to_string(), }, tools: request.tools.clone(), + top_p: request.top_p, + truncation: match &request.truncation { + Truncation::Auto => Some("auto".to_string()), + Truncation::Disabled => Some("disabled".to_string()), + }, + usage: usage.map(ResponsesUsage::Classic), + user: request.user.clone(), + metadata: request.metadata.clone().unwrap_or_default(), } } @@ -1357,13 +1523,26 @@ id: request_id, object: "response".to_string(), created_at: current_timestamp(), + status, + error: None, + incomplete_details: None, + instructions: None, + max_output_tokens: None, model, output: Vec::new(), - status, - usage: None, parallel_tool_calls: true, + previous_response_id: None, + reasoning: None, + store: true, + temperature: None, + text: None, tool_choice: "auto".to_string(), tools: Vec::new(), + top_p: None, + truncation: None, + usage: None, + user: None, + metadata: HashMap::new(), } } @@ -1374,7 +1553,7 @@ impl ResponsesResponse { /// Set the usage information pub fn set_usage(&mut self, usage: UsageInfo) { - self.usage = Some(usage); + self.usage = Some(ResponsesUsage::Classic(usage)); } /// Update the status @@ -1413,12 +1592,12 @@ impl ResponsesResponse { } /// Get the response as a JSON value with usage in response format - pub fn to_response_format(&self) -> serde_json::Value { - let mut response = serde_json::to_value(self).unwrap_or(serde_json::Value::Null); + pub fn to_response_format(&self) -> Value { + let mut response = to_value(self).unwrap_or(Value::Null); // Convert usage to response format if present if let Some(usage) = &self.usage { - if let Ok(usage_value) = serde_json::to_value(usage.to_response_usage()) { + if let Ok(usage_value) = to_value(usage.to_response_usage()) { response["usage"] = usage_value; } } @@ -1641,8 +1820,13 @@ pub struct LogProbs { } #[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatLogProbs { - pub content: Option<Vec<ChatLogProbsContent>>, +#[serde(untagged)] +pub enum ChatLogProbs { + Detailed { + #[serde(skip_serializing_if = "Option::is_none")] + content: Option<Vec<ChatLogProbsContent>>, + }, + Raw(Value), } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -1798,7 +1982,7 @@ pub struct GenerateRequest { /// Session parameters for continual prompting #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option<HashMap<String, serde_json::Value>>, + pub session_params: Option<HashMap<String, Value>>, /// Return model hidden states #[serde(default)] @@ -2065,7 +2249,7 @@ pub struct EmbeddingRequest { pub model: String, /// Input can be a string, array of strings, tokens, or batch inputs - pub input: serde_json::Value, + pub input: Value, /// Optional encoding format (e.g., "float", "base64") #[serde(skip_serializing_if = "Option::is_none")] @@ -2097,8 +2281,8 @@ impl GenerationRequest for EmbeddingRequest { fn extract_text_for_routing(&self) -> String { // Best effort: extract text content for routing decisions match &self.input { - serde_json::Value::String(s) => s.clone(), - serde_json::Value::Array(arr) => arr + Value::String(s) => s.clone(), + Value::Array(arr) => arr .iter() .filter_map(|v| v.as_str()) .collect::<Vec<_>>() @@ -2173,7 +2357,7 @@ pub enum LoRAPath { #[cfg(test)] mod tests { use super::*; 
- use serde_json; + use serde_json::{from_str, json, to_string}; // ================================================================== // = RERANK REQUEST TESTS = @@ -2191,8 +2375,8 @@ mod tests { user: Some("user-456".to_string()), }; - let serialized = serde_json::to_string(&request).unwrap(); - let deserialized: RerankRequest = serde_json::from_str(&serialized).unwrap(); + let serialized = to_string(&request).unwrap(); + let deserialized: RerankRequest = from_str(&serialized).unwrap(); assert_eq!(deserialized.query, request.query); assert_eq!(deserialized.documents, request.documents); @@ -2210,7 +2394,7 @@ mod tests { "documents": ["doc1", "doc2"] }"#; - let request: RerankRequest = serde_json::from_str(json).unwrap(); + let request: RerankRequest = from_str(json).unwrap(); assert_eq!(request.query, "test query"); assert_eq!(request.documents, vec!["doc1", "doc2"]); @@ -2402,8 +2586,8 @@ mod tests { Some(StringOrArray::String("req-123".to_string())), ); - let serialized = serde_json::to_string(&response).unwrap(); - let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap(); + let serialized = to_string(&response).unwrap(); + let deserialized: RerankResponse = from_str(&serialized).unwrap(); assert_eq!(deserialized.results.len(), response.results.len()); assert_eq!(deserialized.model, response.model); @@ -2539,13 +2723,13 @@ mod tests { ("confidence".to_string(), Value::String("high".to_string())), ( "processing_time".to_string(), - Value::Number(serde_json::Number::from(150)), + Value::Number(Number::from(150)), ), ])), }; - let serialized = serde_json::to_string(&result).unwrap(); - let deserialized: RerankResult = serde_json::from_str(&serialized).unwrap(); + let serialized = to_string(&result).unwrap(); + let deserialized: RerankResult = from_str(&serialized).unwrap(); assert_eq!(deserialized.score, result.score); assert_eq!(deserialized.document, result.document); @@ -2562,8 +2746,8 @@ mod tests { meta_info: None, }; - let serialized = serde_json::to_string(&result).unwrap(); - let deserialized: RerankResult = serde_json::from_str(&serialized).unwrap(); + let serialized = to_string(&result).unwrap(); + let deserialized: RerankResult = from_str(&serialized).unwrap(); assert_eq!(deserialized.score, result.score); assert_eq!(deserialized.document, result.document); @@ -2582,8 +2766,8 @@ mod tests { documents: vec!["doc1".to_string(), "doc2".to_string()], }; - let serialized = serde_json::to_string(&v1_input).unwrap(); - let deserialized: V1RerankReqInput = serde_json::from_str(&serialized).unwrap(); + let serialized = to_string(&v1_input).unwrap(); + let deserialized: V1RerankReqInput = from_str(&serialized).unwrap(); assert_eq!(deserialized.query, v1_input.query); assert_eq!(deserialized.documents, v1_input.documents); @@ -2724,8 +2908,8 @@ mod tests { prompt_tokens_details: None, }); - let serialized = serde_json::to_string(&response).unwrap(); - let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap(); + let serialized = to_string(&response).unwrap(); + let deserialized: RerankResponse = from_str(&serialized).unwrap(); assert!(deserialized.usage.is_some()); let usage = deserialized.usage.unwrap(); @@ -2805,8 +2989,8 @@ mod tests { assert_eq!(response.model, "rerank-model"); // Serialize and deserialize - let serialized = serde_json::to_string(&response).unwrap(); - let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap(); + let serialized = to_string(&response).unwrap(); + let deserialized: RerankResponse = 
from_str(&serialized).unwrap(); assert_eq!(deserialized.results.len(), 2); assert_eq!(deserialized.model, response.model); } @@ -2819,15 +3003,15 @@ mod tests { fn test_embedding_request_serialization_string_input() { let req = EmbeddingRequest { model: "test-emb".to_string(), - input: serde_json::Value::String("hello".to_string()), + input: Value::String("hello".to_string()), encoding_format: Some("float".to_string()), user: Some("user-1".to_string()), dimensions: Some(128), rid: Some("rid-123".to_string()), }; - let serialized = serde_json::to_string(&req).unwrap(); - let deserialized: EmbeddingRequest = serde_json::from_str(&serialized).unwrap(); + let serialized = to_string(&req).unwrap(); + let deserialized: EmbeddingRequest = from_str(&serialized).unwrap(); assert_eq!(deserialized.model, req.model); assert_eq!(deserialized.input, req.input); @@ -2841,15 +3025,15 @@ mod tests { fn test_embedding_request_serialization_array_input() { let req = EmbeddingRequest { model: "test-emb".to_string(), - input: serde_json::json!(["a", "b", "c"]), + input: json!(["a", "b", "c"]), encoding_format: None, user: None, dimensions: None, rid: None, }; - let serialized = serde_json::to_string(&req).unwrap(); - let de: EmbeddingRequest = serde_json::from_str(&serialized).unwrap(); + let serialized = to_string(&req).unwrap(); + let de: EmbeddingRequest = from_str(&serialized).unwrap(); assert_eq!(de.model, req.model); assert_eq!(de.input, req.input); } @@ -2858,7 +3042,7 @@ mod tests { fn test_embedding_generation_request_trait_string() { let req = EmbeddingRequest { model: "emb-model".to_string(), - input: serde_json::Value::String("hello".to_string()), + input: Value::String("hello".to_string()), encoding_format: None, user: None, dimensions: None, @@ -2873,7 +3057,7 @@ mod tests { fn test_embedding_generation_request_trait_array() { let req = EmbeddingRequest { model: "emb-model".to_string(), - input: serde_json::json!(["hello", "world"]), + input: json!(["hello", "world"]), encoding_format: None, user: None, dimensions: None, @@ -2886,7 +3070,7 @@ mod tests { fn test_embedding_generation_request_trait_non_text() { let req = EmbeddingRequest { model: "emb-model".to_string(), - input: serde_json::json!({"tokens": [1, 2, 3]}), + input: json!({"tokens": [1, 2, 3]}), encoding_format: None, user: None, dimensions: None, @@ -2899,7 +3083,7 @@ mod tests { fn test_embedding_generation_request_trait_mixed_array_ignores_nested() { let req = EmbeddingRequest { model: "emb-model".to_string(), - input: serde_json::json!(["a", ["b", "c"], 123, {"k": "v"}]), + input: json!(["a", ["b", "c"], 123, {"k": "v"}]), encoding_format: None, user: None, dimensions: None, diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index a1a2cfd64..a5452a6cf 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -166,8 +166,12 @@ impl RouterFactory { .cloned() .ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?; - let router = - OpenAIRouter::new(base_url, Some(ctx.router_config.circuit_breaker.clone())).await?; + let router = OpenAIRouter::new( + base_url, + Some(ctx.router_config.circuit_breaker.clone()), + ctx.response_storage.clone(), + ) + .await?; Ok(Box::new(router)) } diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index d8f9a6bce..b27967a9d 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -308,7 +308,12 @@ impl RouterTrait for 
GrpcPDRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } - async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + async fn get_response( + &self, + _headers: Option<&HeaderMap>, + _response_id: &str, + _params: &crate::protocols::spec::ResponsesGetParams, + ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index f91ce7694..4898fb451 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -237,7 +237,12 @@ impl RouterTrait for GrpcRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } - async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + async fn get_response( + &self, + _headers: Option<&HeaderMap>, + _response_id: &str, + _params: &crate::protocols::spec::ResponsesGetParams, + ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/header_utils.rs b/sgl-router/src/routers/header_utils.rs index 0adab5bf0..13b6f04ef 100644 --- a/sgl-router/src/routers/header_utils.rs +++ b/sgl-router/src/routers/header_utils.rs @@ -51,3 +51,45 @@ fn should_forward_header(name: &str) -> bool { "host" // Should not forward the backend's host header ) } + +/// Apply headers to a reqwest request builder, filtering out headers that shouldn't be forwarded +/// or that will be set automatically by reqwest +pub fn apply_request_headers( + headers: &HeaderMap, + mut request_builder: reqwest::RequestBuilder, + skip_content_headers: bool, +) -> reqwest::RequestBuilder { + // Always forward Authorization header first if present + if let Some(auth) = headers + .get("authorization") + .or_else(|| headers.get("Authorization")) + { + request_builder = request_builder.header("Authorization", auth.clone()); + } + + // Forward other headers, filtering out problematic ones + for (key, value) in headers.iter() { + let key_str = key.as_str().to_lowercase(); + + // Skip headers that: + // - Are set automatically by reqwest (content-type, content-length for POST/PUT) + // - We already handled (authorization) + // - Are hop-by-hop headers (connection, transfer-encoding) + // - Should not be forwarded (host) + let should_skip = key_str == "authorization" || // Already handled above + key_str == "host" || + key_str == "connection" || + key_str == "transfer-encoding" || + key_str == "keep-alive" || + key_str == "te" || + key_str == "trailers" || + key_str == "upgrade" || + (skip_content_headers && (key_str == "content-type" || key_str == "content-length")); + + if !should_skip { + request_builder = request_builder.header(key.clone(), value.clone()); + } + } + + request_builder +} diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index e259a5c39..2187ce898 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -1,10 +1,15 @@ -//! OpenAI router implementation (reqwest-based) +//! 
OpenAI router implementation use crate::config::CircuitBreakerConfig; use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}; +use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse}; use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, + ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem, + ResponseStatus, ResponseTextFormat, ResponsesGetParams, ResponsesRequest, ResponsesResponse, + TextFormatType, }; +use crate::routers::header_utils::{apply_request_headers, preserve_response_headers}; use async_trait::async_trait; use axum::{ body::Body, @@ -13,13 +18,17 @@ use axum::{ response::{IntoResponse, Response}, }; use futures_util::StreamExt; +use serde_json::{json, to_value, Value}; use std::{ any::Any, + collections::HashMap, sync::atomic::{AtomicBool, Ordering}, }; +use tokio::sync::mpsc; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{error, info, warn}; /// Router for OpenAI backend -#[derive(Debug)] pub struct OpenAIRouter { /// HTTP client for upstream OpenAI-compatible API client: reqwest::Client, @@ -29,6 +38,17 @@ pub struct OpenAIRouter { circuit_breaker: CircuitBreaker, /// Health status healthy: AtomicBool, + /// Response storage for managing conversation history + response_storage: SharedResponseStorage, +} + +impl std::fmt::Debug for OpenAIRouter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OpenAIRouter") + .field("base_url", &self.base_url) + .field("healthy", &self.healthy) + .finish() + } } impl OpenAIRouter { @@ -36,6 +56,7 @@ impl OpenAIRouter { pub async fn new( base_url: String, circuit_breaker_config: Option<CircuitBreakerConfig>, + response_storage: SharedResponseStorage, ) -> Result<Self, String> { let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(300)) @@ -61,8 +82,246 @@ impl OpenAIRouter { base_url, circuit_breaker, healthy: AtomicBool::new(true), + response_storage, }) } + + async fn handle_non_streaming_response( + &self, + url: String, + headers: Option<&HeaderMap>, + payload: Value, + original_body: &ResponsesRequest, + original_previous_response_id: Option<String>, + ) -> Response { + let request_builder = self.client.post(&url).json(&payload); + + // Apply headers with filtering + let request_builder = if let Some(headers) = headers { + apply_request_headers(headers, request_builder, true) + } else { + request_builder + }; + + match request_builder.send().await { + Ok(response) => { + let status = response.status(); + + if !status.is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|e| format!("Failed to get error body: {}", e)); + return (status, error_text).into_response(); + } + + // Parse the response + match response.json::<Value>().await { + Ok(mut openai_response_json) => { + if let Some(prev_id) = original_previous_response_id { + if let Some(obj) = openai_response_json.as_object_mut() { + let should_insert = obj + .get("previous_response_id") + .map(|v| v.is_null()) + .unwrap_or(true); + if should_insert { + obj.insert( + "previous_response_id".to_string(), + Value::String(prev_id), + ); + } + } + } + + if let Some(obj) = openai_response_json.as_object_mut() { + if !obj.contains_key("instructions") { + if let Some(instructions) = &original_body.instructions { + obj.insert( + "instructions".to_string(), + Value::String(instructions.clone()), + ); + } + } 
+ + if !obj.contains_key("metadata") { + if let Some(metadata) = &original_body.metadata { + let metadata_map: serde_json::Map<String, Value> = metadata + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + obj.insert("metadata".to_string(), Value::Object(metadata_map)); + } + } + + // Reflect the client's requested store preference in the response body + obj.insert("store".to_string(), Value::Bool(original_body.store)); + } + + if original_body.store { + if let Err(e) = self + .store_response_internal(&openai_response_json, original_body) + .await + { + warn!("Failed to store response: {}", e); + } + } + + match serde_json::to_string(&openai_response_json) { + Ok(json_str) => ( + StatusCode::OK, + [("content-type", "application/json")], + json_str, + ) + .into_response(), + Err(e) => { + error!("Failed to serialize response: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + json!({"error": {"message": "Failed to serialize response", "type": "internal_error"}}).to_string(), + ) + .into_response() + } + } + } + Err(e) => { + error!("Failed to parse OpenAI response: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to parse response: {}", e), + ) + .into_response() + } + } + } + Err(e) => ( + StatusCode::BAD_GATEWAY, + format!("Failed to forward request to OpenAI: {}", e), + ) + .into_response(), + } + } + + async fn handle_streaming_response( + &self, + _url: String, + _headers: Option<&HeaderMap>, + _payload: Value, + _original_body: &ResponsesRequest, + _original_previous_response_id: Option<String>, + ) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Streaming responses not yet implemented", + ) + .into_response() + } + + async fn store_response_internal( + &self, + response_json: &Value, + original_body: &ResponsesRequest, + ) -> Result<(), String> { + if !original_body.store { + return Ok(()); + } + + match Self::store_response_impl(&self.response_storage, response_json, original_body).await + { + Ok(response_id) => { + info!(response_id = %response_id.0, "Stored response locally"); + Ok(()) + } + Err(e) => Err(e), + } + } + + async fn store_response_impl( + response_storage: &SharedResponseStorage, + response_json: &Value, + original_body: &ResponsesRequest, + ) -> Result<ResponseId, String> { + let input_text = match &original_body.input { + ResponseInput::Text(text) => text.clone(), + ResponseInput::Items(_) => "complex input".to_string(), + }; + + let output_text = Self::extract_primary_output_text(response_json).unwrap_or_default(); + + let mut stored_response = StoredResponse::new(input_text, output_text, None); + + stored_response.instructions = response_json + .get("instructions") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .or_else(|| original_body.instructions.clone()); + + stored_response.model = response_json + .get("model") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .or_else(|| original_body.model.clone()); + + stored_response.user = response_json + .get("user") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .or_else(|| original_body.user.clone()); + + stored_response.metadata = response_json + .get("metadata") + .and_then(|v| v.as_object()) + .map(|m| { + m.iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect::<HashMap<String, Value>>() + }) + .unwrap_or_else(|| original_body.metadata.clone().unwrap_or_default()); + + stored_response.previous_response_id = response_json + .get("previous_response_id") + .and_then(|v| v.as_str()) + .map(|s| ResponseId::from_string(s.to_string())) + .or_else(|| { + original_body + .previous_response_id + .as_ref() + .map(|id| 
ResponseId::from_string(id.clone())) + }); + + if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) { + stored_response.id = ResponseId::from_string(id_str.to_string()); + } + + stored_response.raw_response = response_json.clone(); + + response_storage + .store_response(stored_response) + .await + .map_err(|e| format!("Failed to store response: {}", e)) + } + + fn extract_primary_output_text(response_json: &Value) -> Option<String> { + if let Some(items) = response_json.get("output").and_then(|v| v.as_array()) { + for item in items { + if let Some(content) = item.get("content").and_then(|v| v.as_array()) { + for part in content { + if part + .get("type") + .and_then(|v| v.as_str()) + .map(|t| t == "output_text") + .unwrap_or(false) + { + if let Some(text) = part.get("text").and_then(|v| v.as_str()) { + return Some(text.to_string()); + } + } + } + } + } + } + + None + } } #[async_trait] @@ -108,7 +367,7 @@ impl super::super::RouterTrait for OpenAIRouter { } async fn get_server_info(&self, _req: Request) -> Response { - let info = serde_json::json!({ + let info = json!({ "router_type": "openai", "workers": 1, "base_url": &self.base_url @@ -192,7 +451,7 @@ } // Serialize request body, removing SGLang-only fields - let mut payload = match serde_json::to_value(body) { + let mut payload = match to_value(body) { Ok(v) => v, Err(e) => { return ( @@ -282,7 +541,7 @@ } else { // Stream SSE bytes to client let stream = resp.bytes_stream(); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let (tx, rx) = mpsc::unbounded_channel(); tokio::spawn(async move { let mut s = stream; while let Some(chunk) = s.next().await { @@ -299,9 +558,7 @@ } } }); - let mut response = Response::new(Body::from_stream( - tokio_stream::wrappers::UnboundedReceiverStream::new(rx), - )); + let mut response = Response::new(Body::from_stream(UnboundedReceiverStream::new(rx))); *response.status_mut() = status; response .headers_mut() @@ -326,36 +583,294 @@ async fn route_responses( &self, - _headers: Option<&HeaderMap>, - _body: &crate::protocols::spec::ResponsesRequest, - _model_id: Option<&str>, + headers: Option<&HeaderMap>, + body: &ResponsesRequest, + model_id: Option<&str>, ) -> Response { + let url = format!("{}/v1/responses", self.base_url); + + info!( + requested_store = body.store, + is_streaming = body.stream, + "openai_responses_request" + ); + + if body.stream { + return ( + StatusCode::NOT_IMPLEMENTED, + "Streaming responses not yet implemented", + ) + .into_response(); + } + + // Clone the body and override model if needed + let mut request_body = body.clone(); + if let Some(model) = model_id { + request_body.model = Some(model.to_string()); + } + + // Store the original previous_response_id for the response + let original_previous_response_id = request_body.previous_response_id.clone(); + + // Handle previous_response_id by loading prior context + let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None; + if let Some(prev_id_str) = request_body.previous_response_id.clone() { + let prev_id = ResponseId::from_string(prev_id_str.clone()); + match self + .response_storage + .get_response_chain(&prev_id, None) + .await + { + Ok(chain) => { + if !chain.responses.is_empty() { + let mut items = Vec::new(); + for stored in chain.responses.iter() { + let trimmed_id = stored.id.0.trim_start_matches("resp_"); + if 
!stored.input.is_empty() { + items.push(ResponseInputOutputItem::Message { + id: format!("msg_u_{}", trimmed_id), + role: "user".to_string(), + status: Some("completed".to_string()), + content: vec![ResponseContentPart::InputText { + text: stored.input.clone(), + }], + }); + } + if !stored.output.is_empty() { + items.push(ResponseInputOutputItem::Message { + id: format!("msg_a_{}", trimmed_id), + role: "assistant".to_string(), + status: Some("completed".to_string()), + content: vec![ResponseContentPart::OutputText { + text: stored.output.clone(), + annotations: vec![], + logprobs: None, + }], + }); + } + } + conversation_items = Some(items); + } else { + info!(previous_response_id = %prev_id_str, "previous chain empty"); + } + } + Err(err) => { + warn!(previous_response_id = %prev_id_str, %err, "failed to fetch previous response chain"); + } + } + // Clear previous_response_id from request since we're converting to conversation + request_body.previous_response_id = None; + } + + if let Some(mut items) = conversation_items { + match &request_body.input { + ResponseInput::Text(text) => { + items.push(ResponseInputOutputItem::Message { + id: format!("msg_u_current_{}", items.len()), + role: "user".to_string(), + status: Some("completed".to_string()), + content: vec![ResponseContentPart::InputText { text: text.clone() }], + }); + } + ResponseInput::Items(existing) => { + items.extend(existing.clone()); + } + } + request_body.input = ResponseInput::Items(items); + } + + // Always set store=false for OpenAI (we store internally) + request_body.store = false; + + // Convert to JSON payload and strip SGLang-specific fields before forwarding + let mut payload = match to_value(&request_body) { + Ok(value) => value, + Err(err) => { + return ( + StatusCode::BAD_REQUEST, + format!("Failed to serialize responses request: {}", err), + ) + .into_response(); + } + }; + if let Some(obj) = payload.as_object_mut() { + for key in [ + "request_id", + "priority", + "frequency_penalty", + "presence_penalty", + "stop", + "top_k", + "min_p", + "repetition_penalty", + ] { + obj.remove(key); + } + } + + // Check if streaming is requested + if body.stream { + // Handle streaming response + self.handle_streaming_response( + url, + headers, + payload, + body, + original_previous_response_id, + ) + .await + } else { + // Handle non-streaming response + self.handle_non_streaming_response( + url, + headers, + payload, + body, + original_previous_response_id, + ) + .await + } + } + + async fn get_response( + &self, + _headers: Option<&HeaderMap>, + response_id: &str, + params: &ResponsesGetParams, + ) -> Response { + let stored_id = ResponseId::from_string(response_id.to_string()); + if let Ok(Some(stored_response)) = self.response_storage.get_response(&stored_id).await { + let stream_requested = params.stream.unwrap_or(false); + let raw_value = stored_response.raw_response.clone(); + + if !raw_value.is_null() { + if stream_requested { + return ( + StatusCode::NOT_IMPLEMENTED, + "Streaming retrieval not yet implemented", + ) + .into_response(); + } + + return ( + StatusCode::OK, + [("content-type", "application/json")], + raw_value.to_string(), + ) + .into_response(); + } + + let openai_response = ResponsesResponse { + id: stored_response.id.0.clone(), + object: "response".to_string(), + created_at: stored_response.created_at.timestamp(), + status: ResponseStatus::Completed, + error: None, + incomplete_details: None, + instructions: stored_response.instructions.clone(), + max_output_tokens: None, + model: stored_response + 
.model + .unwrap_or_else(|| "gpt-4o".to_string()), + output: vec![ResponseOutputItem::Message { + id: format!("msg_{}", stored_response.id.0), + role: "assistant".to_string(), + status: "completed".to_string(), + content: vec![ResponseContentPart::OutputText { + text: stored_response.output, + annotations: vec![], + logprobs: None, + }], + }], + parallel_tool_calls: true, + previous_response_id: stored_response.previous_response_id.map(|id| id.0), + reasoning: None, + store: true, + temperature: Some(1.0), + text: Some(ResponseTextFormat { + format: TextFormatType { + format_type: "text".to_string(), + }, + }), + tool_choice: "auto".to_string(), + tools: vec![], + top_p: Some(1.0), + truncation: Some("disabled".to_string()), + usage: None, + user: stored_response.user.clone(), + metadata: stored_response.metadata.clone(), + }; + + if stream_requested { + return ( + StatusCode::NOT_IMPLEMENTED, + "Streaming retrieval not yet implemented", + ) + .into_response(); + } + + return ( + StatusCode::OK, + [("content-type", "application/json")], + serde_json::to_string(&openai_response).unwrap_or_else(|e| { + format!("{{\"error\": \"Failed to serialize response: {}\"}}", e) + }), + ) + .into_response(); + } + ( - StatusCode::NOT_IMPLEMENTED, - "Responses endpoint not implemented for OpenAI router", + StatusCode::NOT_FOUND, + format!( + "Response with id '{}' not found in local storage", + response_id + ), ) .into_response() } - async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { - ( - StatusCode::NOT_IMPLEMENTED, - "Responses retrieve endpoint not implemented for OpenAI router", - ) - .into_response() - } + async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { + // Forward to OpenAI's cancel endpoint + let url = format!("{}/v1/responses/{}/cancel", self.base_url, response_id); - async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { - ( - StatusCode::NOT_IMPLEMENTED, - "Responses cancel endpoint not implemented for OpenAI router", - ) - .into_response() + let request_builder = self.client.post(&url); + + // Apply headers with filtering (skip content headers for POST without body) + let request_builder = if let Some(headers) = headers { + apply_request_headers(headers, request_builder, true) + } else { + request_builder + }; + + match request_builder.send().await { + Ok(response) => { + let status = response.status(); + let headers = response.headers().clone(); + + match response.text().await { + Ok(body_text) => { + let mut response = (status, body_text).into_response(); + *response.headers_mut() = preserve_response_headers(&headers); + response + } + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to read response body: {}", e), + ) + .into_response(), + } + } + Err(e) => ( + StatusCode::BAD_GATEWAY, + format!("Failed to cancel response on OpenAI: {}", e), + ) + .into_response(), + } } async fn flush_cache(&self) -> Response { ( - StatusCode::NOT_IMPLEMENTED, + StatusCode::FORBIDDEN, "flush_cache not supported for OpenAI router", ) .into_response() @@ -363,7 +878,7 @@ impl super::super::RouterTrait for OpenAIRouter { async fn get_worker_loads(&self) -> Response { ( - StatusCode::NOT_IMPLEMENTED, + StatusCode::FORBIDDEN, "get_worker_loads not supported for OpenAI router", ) .into_response() @@ -384,12 +899,12 @@ impl super::super::RouterTrait for OpenAIRouter { async fn route_embeddings( &self, _headers: Option<&HeaderMap>, - _body: 
&crate::protocols::spec::EmbeddingRequest, + _body: &EmbeddingRequest, _model_id: Option<&str>, ) -> Response { ( - StatusCode::NOT_IMPLEMENTED, - "Embeddings endpoint not implemented for OpenAI backend", + StatusCode::FORBIDDEN, + "Embeddings endpoint not supported for OpenAI backend", ) .into_response() } @@ -401,8 +916,8 @@ impl super::super::RouterTrait for OpenAIRouter { _model_id: Option<&str>, ) -> Response { ( - StatusCode::NOT_IMPLEMENTED, - "Rerank endpoint not implemented for OpenAI backend", + StatusCode::FORBIDDEN, + "Rerank endpoint not supported for OpenAI backend", ) .into_response() } diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 23ff7ab57..356f7fcd4 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -8,7 +8,7 @@ use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::protocols::spec::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, RerankRequest, - ResponsesRequest, StringOrArray, UserMessageContent, + ResponsesGetParams, ResponsesRequest, StringOrArray, UserMessageContent, }; use crate::routers::header_utils; use crate::routers::RouterTrait; @@ -1424,7 +1424,12 @@ impl RouterTrait for PDRouter { .into_response() } - async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + async fn get_response( + &self, + _headers: Option<&HeaderMap>, + _response_id: &str, + _params: &ResponsesGetParams, + ) -> Response { ( StatusCode::NOT_IMPLEMENTED, "Responses retrieve endpoint not implemented for PD router", diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index 8523b9d4e..28d701ef2 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -6,7 +6,7 @@ use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest, - RerankRequest, RerankResponse, RerankResult, ResponsesRequest, + RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest, }; use crate::routers::header_utils; use crate::routers::RouterTrait; @@ -903,7 +903,12 @@ impl RouterTrait for Router { .await } - async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { + async fn get_response( + &self, + headers: Option<&HeaderMap>, + response_id: &str, + _params: &ResponsesGetParams, + ) -> Response { let endpoint = format!("v1/responses/{}", response_id); self.route_get_request(headers, &endpoint).await } diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 35d91fcfc..05cbdf743 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -11,7 +11,7 @@ use std::fmt::Debug; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, - ResponsesRequest, + ResponsesGetParams, ResponsesRequest, }; pub mod factory; @@ -82,7 +82,12 @@ pub trait RouterTrait: Send + Sync + Debug { ) -> Response; /// Retrieve a stored/background response by id - async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response; + async fn get_response( + &self, + headers: Option<&HeaderMap>, + response_id: &str, + params: &ResponsesGetParams, + ) -> Response; /// Cancel a background response by id async fn 
cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response; diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index 27bd0299a..6bd01996c 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -8,7 +8,7 @@ use crate::config::{ConnectionMode, RoutingMode}; use crate::core::{WorkerRegistry, WorkerType}; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, - ResponsesRequest, + ResponsesGetParams, ResponsesRequest, }; use crate::routers::RouterTrait; use crate::server::{AppContext, ServerConfig}; @@ -403,38 +403,19 @@ impl RouterTrait for RouterManager { async fn route_responses( &self, - _headers: Option<&HeaderMap>, - _body: &ResponsesRequest, - _model_id: Option<&str>, + headers: Option<&HeaderMap>, + body: &ResponsesRequest, + model_id: Option<&str>, ) -> Response { - ( - StatusCode::NOT_IMPLEMENTED, - "responses api not yet implemented in inference gateway mode", - ) - .into_response() - } + let selected_model = body.model.as_deref().or(model_id); + let router = self.select_router_for_request(headers, selected_model); - async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { - let router = self.select_router_for_request(headers, None); if let Some(router) = router { - router.get_response(headers, response_id).await + router.route_responses(headers, body, selected_model).await } else { ( StatusCode::NOT_FOUND, - format!("No router available to get response '{}'", response_id), - ) - .into_response() - } - } - - async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { - let router = self.select_router_for_request(headers, None); - if let Some(router) = router { - router.cancel_response(headers, response_id).await - } else { - ( - StatusCode::NOT_FOUND, - format!("No router available to cancel response '{}'", response_id), + "No router available to handle responses request", ) .into_response() } @@ -460,6 +441,37 @@ impl RouterTrait for RouterManager { .into_response() } + async fn get_response( + &self, + headers: Option<&HeaderMap>, + response_id: &str, + params: &ResponsesGetParams, + ) -> Response { + let router = self.select_router_for_request(headers, None); + if let Some(router) = router { + router.get_response(headers, response_id, params).await + } else { + ( + StatusCode::NOT_FOUND, + format!("No router available to get response '{}'", response_id), + ) + .into_response() + } + } + + async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { + let router = self.select_router_for_request(headers, None); + if let Some(router) = router { + router.cancel_response(headers, response_id).await + } else { + ( + StatusCode::NOT_FOUND, + format!("No router available to cancel response '{}'", response_id), + ) + .into_response() + } + } + async fn route_embeddings( &self, headers: Option<&HeaderMap>, diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index cc60b8424..28f69eb5f 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -9,7 +9,7 @@ use crate::{ protocols::{ spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, - RerankRequest, ResponsesRequest, V1RerankReqInput, + RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput, }, worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}, }, @@ -224,10 
+224,11 @@ async fn v1_responses_get( State(state): State<Arc<AppState>>, Path(response_id): Path<String>, headers: http::HeaderMap, + Query(params): Query<ResponsesGetParams>, ) -> Response { state .router - .get_response(Some(&headers), &response_id) + .get_response(Some(&headers), &response_id, &params) .await } diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs index cfa12389f..624ab9080 100644 --- a/sgl-router/tests/test_openai_routing.rs +++ b/sgl-router/tests/test_openai_routing.rs @@ -5,17 +5,23 @@ use axum::{ extract::Request, http::{Method, StatusCode}, routing::post, - Router, + Json, Router, }; use serde_json::json; use sglang_router_rs::{ config::{RouterConfig, RoutingMode}, + data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage}, protocols::spec::{ - ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, UserMessageContent, + ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput, + ResponsesGetParams, ResponsesRequest, UserMessageContent, }, routers::{openai_router::OpenAIRouter, RouterTrait}, }; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use tokio::net::TcpListener; use tower::ServiceExt; mod common; @@ -78,7 +84,12 @@ fn create_minimal_completion_request() -> CompletionRequest { /// Test basic OpenAI router creation and configuration #[tokio::test] async fn test_openai_router_creation() { - let router = OpenAIRouter::new("https://api.openai.com".to_string(), None).await; + let router = OpenAIRouter::new( + "https://api.openai.com".to_string(), + None, + Arc::new(MemoryResponseStorage::new()), + ) + .await; assert!(router.is_ok(), "Router creation should succeed"); @@ -90,9 +101,13 @@ /// Test health endpoints #[tokio::test] async fn test_openai_router_health() { - let router = OpenAIRouter::new("https://api.openai.com".to_string(), None) - .await - .unwrap(); + let router = OpenAIRouter::new( + "https://api.openai.com".to_string(), + None, + Arc::new(MemoryResponseStorage::new()), + ) + .await + .unwrap(); let req = Request::builder() .method(Method::GET) @@ -107,9 +122,13 @@ /// Test server info endpoint #[tokio::test] async fn test_openai_router_server_info() { - let router = OpenAIRouter::new("https://api.openai.com".to_string(), None) - .await - .unwrap(); + let router = OpenAIRouter::new( + "https://api.openai.com".to_string(), + None, + Arc::new(MemoryResponseStorage::new()), + ) + .await + .unwrap(); let req = Request::builder() .method(Method::GET) @@ -132,9 +151,13 @@ async fn test_openai_router_models() { // Use mock server for deterministic models response let mock_server = MockOpenAIServer::new().await; - let router = OpenAIRouter::new(mock_server.base_url(), None) - .await - .unwrap(); + let router = OpenAIRouter::new( + mock_server.base_url(), + None, + Arc::new(MemoryResponseStorage::new()), + ) + .await + .unwrap(); let req = Request::builder() .method(Method::GET) @@ -154,6 +177,138 @@ assert!(models["data"].is_array()); } +#[tokio::test] +async fn test_openai_router_responses_with_mock() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let counter = Arc::new(AtomicUsize::new(0)); + let counter_clone = counter.clone(); + + let app = Router::new().route( + "/v1/responses", + post({ + move |Json(request): Json<serde_json::Value>| { + let counter = 
counter_clone.clone(); + async move { + let idx = counter.fetch_add(1, Ordering::SeqCst) + 1; + let model = request + .get("model") + .and_then(|v| v.as_str()) + .unwrap_or("gpt-4o-mini") + .to_string(); + let id = format!("resp_mock_{idx}"); + let response = json!({ + "id": id, + "object": "response", + "created_at": 1_700_000_000 + idx as i64, + "status": "completed", + "model": model, + "output": [{ + "type": "message", + "id": format!("msg_{idx}"), + "role": "assistant", + "status": "completed", + "content": [{ + "type": "output_text", + "text": format!("mock_output_{idx}"), + "annotations": [] + }] + }], + "metadata": {} + }); + Json(response) + } + } + }), + ); + + let server = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let base_url = format!("http://{}", addr); + let storage = Arc::new(MemoryResponseStorage::new()); + + let router = OpenAIRouter::new(base_url, None, storage.clone()) + .await + .unwrap(); + + let request1 = ResponsesRequest { + model: Some("gpt-4o-mini".to_string()), + input: ResponseInput::Text("Say hi".to_string()), + store: true, + ..Default::default() + }; + + let response1 = router.route_responses(None, &request1, None).await; + assert_eq!(response1.status(), StatusCode::OK); + let body1_bytes = axum::body::to_bytes(response1.into_body(), usize::MAX) + .await + .unwrap(); + let body1: serde_json::Value = serde_json::from_slice(&body1_bytes).unwrap(); + let resp1_id = body1["id"].as_str().expect("id missing").to_string(); + assert_eq!(body1["previous_response_id"], serde_json::Value::Null); + + let request2 = ResponsesRequest { + model: Some("gpt-4o-mini".to_string()), + input: ResponseInput::Text("Thanks".to_string()), + store: true, + previous_response_id: Some(resp1_id.clone()), + ..Default::default() + }; + + let response2 = router.route_responses(None, &request2, None).await; + assert_eq!(response2.status(), StatusCode::OK); + let body2_bytes = axum::body::to_bytes(response2.into_body(), usize::MAX) + .await + .unwrap(); + let body2: serde_json::Value = serde_json::from_slice(&body2_bytes).unwrap(); + let resp2_id = body2["id"].as_str().expect("second id missing"); + assert_eq!( + body2["previous_response_id"].as_str(), + Some(resp1_id.as_str()) + ); + + let stored1 = storage + .get_response(&ResponseId::from_string(resp1_id.clone())) + .await + .unwrap() + .expect("first response missing"); + assert_eq!(stored1.input, "Say hi"); + assert_eq!(stored1.output, "mock_output_1"); + assert!(stored1.previous_response_id.is_none()); + + let stored2 = storage + .get_response(&ResponseId::from_string(resp2_id.to_string())) + .await + .unwrap() + .expect("second response missing"); + assert_eq!(stored2.previous_response_id.unwrap().0, resp1_id); + assert_eq!(stored2.output, "mock_output_2"); + + let get1 = router + .get_response(None, &stored1.id.0, &ResponsesGetParams::default()) + .await; + assert_eq!(get1.status(), StatusCode::OK); + let get1_body_bytes = axum::body::to_bytes(get1.into_body(), usize::MAX) + .await + .unwrap(); + let get1_json: serde_json::Value = serde_json::from_slice(&get1_body_bytes).unwrap(); + assert_eq!(get1_json, body1); + + let get2 = router + .get_response(None, &stored2.id.0, &ResponsesGetParams::default()) + .await; + assert_eq!(get2.status(), StatusCode::OK); + let get2_body_bytes = axum::body::to_bytes(get2.into_body(), usize::MAX) + .await + .unwrap(); + let get2_json: serde_json::Value = serde_json::from_slice(&get2_body_bytes).unwrap(); + assert_eq!(get2_json, body2); + + server.abort(); +} + 
/// Test router factory with OpenAI routing mode #[tokio::test] async fn test_router_factory_openai_mode() { @@ -179,9 +334,13 @@ async fn test_router_factory_openai_mode() { /// Test that unsupported endpoints return proper error codes #[tokio::test] async fn test_unsupported_endpoints() { - let router = OpenAIRouter::new("https://api.openai.com".to_string(), None) - .await - .unwrap(); + let router = OpenAIRouter::new( + "https://api.openai.com".to_string(), + None, + Arc::new(MemoryResponseStorage::new()), + ) + .await + .unwrap(); // Test generate endpoint (SGLang-specific, should not be supported) let generate_request = GenerateRequest { @@ -219,7 +378,9 @@ async fn test_openai_router_chat_completion_with_mock() { let base_url = mock_server.base_url(); // Create router pointing to mock server - let router = OpenAIRouter::new(base_url, None).await.unwrap(); + let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new())) + .await + .unwrap(); // Create a minimal chat completion request let mut chat_request = create_minimal_chat_request(); @@ -255,7 +416,9 @@ async fn test_openai_e2e_with_server() { let base_url = mock_server.base_url(); // Create router - let router = OpenAIRouter::new(base_url, None).await.unwrap(); + let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new())) + .await + .unwrap(); // Create Axum app with chat completions endpoint let app = Router::new().route( @@ -319,7 +482,9 @@ async fn test_openai_e2e_with_server() { async fn test_openai_router_chat_streaming_with_mock() { let mock_server = MockOpenAIServer::new().await; let base_url = mock_server.base_url(); - let router = OpenAIRouter::new(base_url, None).await.unwrap(); + let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new())) + .await + .unwrap(); // Build a streaming chat request let val = json!({ @@ -368,6 +533,7 @@ async fn test_openai_router_circuit_breaker() { let router = OpenAIRouter::new( "http://invalid-url-that-will-fail".to_string(), Some(cb_config), + Arc::new(MemoryResponseStorage::new()), ) .await .unwrap(); @@ -391,9 +557,13 @@ async fn test_openai_router_models_auth_forwarding() { // Start a mock server that requires Authorization let expected_auth = "Bearer test-token".to_string(); let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await; - let router = OpenAIRouter::new(mock_server.base_url(), None) - .await - .unwrap(); + let router = OpenAIRouter::new( + mock_server.base_url(), + None, + Arc::new(MemoryResponseStorage::new()), + ) + .await + .unwrap(); // 1) Without auth header -> expect 401 let req = Request::builder()