diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs
index 10998b718..5a8f3b7d5 100644
--- a/sgl-router/src/protocols/spec.rs
+++ b/sgl-router/src/protocols/spec.rs
@@ -1073,8 +1073,8 @@ fn generate_request_id() -> String {
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ResponsesRequest {
     /// Run the request in the background
-    #[serde(default)]
-    pub background: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub background: Option<bool>,
 
     /// Fields to include in the response
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1108,8 +1108,8 @@ pub struct ResponsesRequest {
     pub conversation: Option<String>,
 
     /// Whether to enable parallel tool calls
-    #[serde(default = "default_true")]
-    pub parallel_tool_calls: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option<bool>,
 
     /// ID of previous response to continue from
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1120,40 +1120,40 @@ pub struct ResponsesRequest {
     pub reasoning: Option<ResponseReasoningParam>,
 
     /// Service tier
-    #[serde(default)]
-    pub service_tier: ServiceTier,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub service_tier: Option<ServiceTier>,
 
     /// Whether to store the response
-    #[serde(default = "default_true")]
-    pub store: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub store: Option<bool>,
 
     /// Whether to stream the response
-    #[serde(default)]
-    pub stream: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
 
     /// Temperature for sampling
     #[serde(skip_serializing_if = "Option::is_none")]
     pub temperature: Option<f32>,
 
     /// Tool choice behavior
-    #[serde(default)]
-    pub tool_choice: ToolChoice,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
 
     /// Available tools
-    #[serde(default)]
-    pub tools: Vec<ResponseTool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<ResponseTool>>,
 
     /// Number of top logprobs to return
-    #[serde(default)]
-    pub top_logprobs: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_logprobs: Option<u32>,
 
     /// Top-p sampling parameter
     #[serde(skip_serializing_if = "Option::is_none")]
     pub top_p: Option<f32>,
 
     /// Truncation behavior
-    #[serde(default)]
-    pub truncation: Truncation,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub truncation: Option<Truncation>,
 
     /// User identifier
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1168,12 +1168,12 @@ pub struct ResponsesRequest {
     pub priority: i32,
 
     /// Frequency penalty
-    #[serde(default)]
-    pub frequency_penalty: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f32>,
 
     /// Presence penalty
-    #[serde(default)]
-    pub presence_penalty: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f32>,
 
     /// Stop sequences
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1210,7 +1210,7 @@ fn default_repetition_penalty() -> f32 {
 impl Default for ResponsesRequest {
     fn default() -> Self {
         Self {
-            background: false,
+            background: None,
             include: None,
             input: ResponseInput::Text(String::new()),
             instructions: None,
@@ -1219,23 +1219,23 @@ impl Default for ResponsesRequest {
             metadata: None,
             model: None,
             conversation: None,
-            parallel_tool_calls: true,
+            parallel_tool_calls: None,
             previous_response_id: None,
             reasoning: None,
-            service_tier: ServiceTier::default(),
-            store: true,
-            stream: false,
+            service_tier: None,
+            store: None,
+            stream: None,
             temperature: None,
-            tool_choice: ToolChoice::default(),
-            tools: Vec::new(),
-            top_logprobs: 0,
+            tool_choice: None,
+            tools: None,
+            top_logprobs: None,
             top_p: None,
-            truncation: Truncation::default(),
+            truncation: None,
             user: None,
             request_id: generate_request_id(),
             priority: 0,
-            frequency_penalty: 0.0,
-            presence_penalty: 0.0,
+            frequency_penalty: None,
+            presence_penalty: None,
             stop: None,
             top_k: default_top_k(),
             min_p: 0.0,
@@ -1299,14 +1299,18 @@ impl ResponsesRequest {
                 "top_p".to_string(),
                 Value::Number(Number::from_f64(top_p as f64).unwrap()),
             );
-        params.insert(
-            "frequency_penalty".to_string(),
-            Value::Number(Number::from_f64(self.frequency_penalty as f64).unwrap()),
-        );
-        params.insert(
-            "presence_penalty".to_string(),
-            Value::Number(Number::from_f64(self.presence_penalty as f64).unwrap()),
-        );
+        if let Some(fp) = self.frequency_penalty {
+            params.insert(
+                "frequency_penalty".to_string(),
+                Value::Number(Number::from_f64(fp as f64).unwrap()),
+            );
+        }
+        if let Some(pp) = self.presence_penalty {
+            params.insert(
+                "presence_penalty".to_string(),
+                Value::Number(Number::from_f64(pp as f64).unwrap()),
+            );
+        }
         params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
         params.insert(
             "min_p".to_string(),
@@ -1337,7 +1341,7 @@ impl ResponsesRequest {
 
 impl GenerationRequest for ResponsesRequest {
     fn is_stream(&self) -> bool {
-        self.stream
+        self.stream.unwrap_or(false)
     }
 
     fn get_model(&self) -> Option<&str> {
@@ -1523,13 +1527,13 @@ impl ResponsesResponse {
             max_output_tokens: request.max_output_tokens,
             model: model_name,
             output,
-            parallel_tool_calls: request.parallel_tool_calls,
+            parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
             previous_response_id: request.previous_response_id.clone(),
             reasoning: request.reasoning.as_ref().map(|r| ReasoningInfo {
                 effort: r.effort.as_ref().map(|e| format!("{:?}", e)),
                 summary: None,
             }),
-            store: request.store,
+            store: request.store.unwrap_or(false),
             temperature: request.temperature,
             text: Some(ResponseTextFormat {
                 format: TextFormatType {
@@ -1537,17 +1541,19 @@ impl ResponsesResponse {
                 },
             }),
             tool_choice: match &request.tool_choice {
-                ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(),
-                ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
-                ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
-                ToolChoice::Function { .. } => "function".to_string(),
-                ToolChoice::AllowedTools { mode, .. } => mode.clone(),
+                Some(ToolChoice::Value(ToolChoiceValue::Auto)) => "auto".to_string(),
+                Some(ToolChoice::Value(ToolChoiceValue::Required)) => "required".to_string(),
+                Some(ToolChoice::Value(ToolChoiceValue::None)) => "none".to_string(),
+                Some(ToolChoice::Function { .. }) => "function".to_string(),
+                Some(ToolChoice::AllowedTools { mode, .. }) => mode.clone(),
+                None => "auto".to_string(),
             },
-            tools: request.tools.clone(),
+            tools: request.tools.clone().unwrap_or_default(),
             top_p: request.top_p,
             truncation: match &request.truncation {
-                Truncation::Auto => Some("auto".to_string()),
-                Truncation::Disabled => Some("disabled".to_string()),
+                Some(Truncation::Auto) => Some("auto".to_string()),
+                Some(Truncation::Disabled) => Some("disabled".to_string()),
+                None => None,
             },
             usage: usage.map(ResponsesUsage::Classic),
             user: request.user.clone(),
diff --git a/sgl-router/src/routers/openai/mcp.rs b/sgl-router/src/routers/openai/mcp.rs
index d50fab6d0..d23ca396a 100644
--- a/sgl-router/src/routers/openai/mcp.rs
+++ b/sgl-router/src/routers/openai/mcp.rs
@@ -689,9 +689,13 @@ pub(super) async fn execute_tool_loop(
     if state.total_calls > 0 {
         let server_label = original_body
             .tools
-            .iter()
-            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-            .and_then(|t| t.server_label.as_deref())
+            .as_ref()
+            .and_then(|tools| {
+                tools
+                    .iter()
+                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                    .and_then(|t| t.server_label.as_deref())
+            })
             .unwrap_or("mcp");
 
         // Build mcp_list_tools item
@@ -747,9 +751,13 @@ pub(super) fn build_incomplete_response(
     if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
         let server_label = original_body
             .tools
-            .iter()
-            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-            .and_then(|t| t.server_label.as_deref())
+            .as_ref()
+            .and_then(|tools| {
+                tools
+                    .iter()
+                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                    .and_then(|t| t.server_label.as_deref())
+            })
             .unwrap_or("mcp");
 
         // Find any function_call items and convert them to mcp_call (incomplete)
diff --git a/sgl-router/src/routers/openai/responses.rs b/sgl-router/src/routers/openai/responses.rs
index 58ca966b5..3c5a73d28 100644
--- a/sgl-router/src/routers/openai/responses.rs
+++ b/sgl-router/src/routers/openai/responses.rs
@@ -129,7 +129,10 @@ pub(super) fn patch_streaming_response_json(
         }
     }
 
-    obj.insert("store".to_string(), Value::Bool(original_body.store));
+    obj.insert(
+        "store".to_string(),
+        Value::Bool(original_body.store.unwrap_or(false)),
+    );
 
     if obj
         .get("model")
@@ -205,7 +208,7 @@ pub(super) fn rewrite_streaming_block(
     let mut changed = false;
 
     if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) {
-        let desired_store = Value::Bool(original_body.store);
+        let desired_store = Value::Bool(original_body.store.unwrap_or(false));
         if response_obj.get("store") != Some(&desired_store) {
             response_obj.insert("store".to_string(), desired_store);
             changed = true;
@@ -267,10 +270,11 @@ pub(super) fn rewrite_streaming_block(
 
 /// Mask function tools as MCP tools in response for client
 pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
-    let mcp_tool = original_body
-        .tools
-        .iter()
-        .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some());
+    let mcp_tool = original_body.tools.as_ref().and_then(|tools| {
+        tools
+            .iter()
+            .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())
+    });
     let Some(t) = mcp_tool else {
         return;
     };
diff --git a/sgl-router/src/routers/openai/router.rs b/sgl-router/src/routers/openai/router.rs
index 11ce66fd2..607a94dd3 100644
--- a/sgl-router/src/routers/openai/router.rs
+++ b/sgl-router/src/routers/openai/router.rs
@@ -148,7 +148,11 @@ impl OpenAIRouter {
         original_previous_response_id: Option<String>,
     ) -> Response {
         // Check if MCP is active for this request
-        let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
+        let req_mcp_manager = if let Some(ref tools) = original_body.tools {
+            mcp_manager_from_request_tools(tools.as_slice()).await
+        } else {
+            None
+        };
         let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());
 
         let mut response_json: Value;
@@ -183,6 +187,7 @@ impl OpenAIRouter {
             }
         } else {
             // No MCP - simple request
+
             let mut request_builder = self.client.post(&url).json(&payload);
             if let Some(h) = headers {
                 request_builder = apply_request_headers(h, request_builder, true);
@@ -385,6 +390,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             }
         };
         if let Some(obj) = payload.as_object_mut() {
+            // Always remove SGLang-specific fields (unsupported by OpenAI)
             for key in [
                 "top_k",
                 "min_p",
@@ -535,7 +541,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
                 .into_response();
         }
 
-        // Clone the body and override model if needed
+        // Clone the body for validation and logic; the upstream payload is built separately
         let mut request_body = body.clone();
         if let Some(model) = model_id {
             request_body.model = Some(model.to_string());
@@ -690,7 +696,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         }
 
         // Always set store=false for upstream (we store internally)
-        request_body.store = false;
+        request_body.store = Some(false);
 
         // Convert to JSON and strip SGLang-specific fields
         let mut payload = match to_value(&request_body) {
@@ -704,14 +710,13 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             }
         };
 
-        // Remove SGLang-specific fields
+        // Remove SGLang-specific fields only
         if let Some(obj) = payload.as_object_mut() {
+            // Remove SGLang-specific fields (not part of the OpenAI API)
             for key in [
                 "request_id",
                 "priority",
                 "top_k",
-                "frequency_penalty",
-                "presence_penalty",
                 "min_p",
                 "min_tokens",
                 "regex",
@@ -732,10 +737,38 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             ] {
                 obj.remove(key);
             }
+            // xAI doesn't support the OpenAI item-type input: https://platform.openai.com/docs/api-reference/responses/create#responses-create-input-input-item-list-item
+            // To achieve xAI compatibility, strip extra fields from input messages (id, status).
+            // xAI also doesn't support output_text as the content type for assistant-role
+            // messages, so normalize content types: output_text -> input_text.
+            if let Some(input_arr) = obj.get_mut("input").and_then(Value::as_array_mut) {
+                for item_obj in input_arr.iter_mut().filter_map(Value::as_object_mut) {
+                    // Remove fields not universally supported
+                    item_obj.remove("id");
+                    item_obj.remove("status");
+
+                    // Normalize content types to input_text (xAI compatibility)
+                    if let Some(content_arr) =
+                        item_obj.get_mut("content").and_then(Value::as_array_mut)
+                    {
+                        for content_obj in content_arr.iter_mut().filter_map(Value::as_object_mut) {
+                            // Change output_text to input_text
+                            if content_obj.get("type").and_then(Value::as_str)
+                                == Some("output_text")
+                            {
+                                content_obj.insert(
+                                    "type".to_string(),
+                                    Value::String("input_text".to_string()),
+                                );
+                            }
+                        }
+                    }
+                }
+            }
         }
 
         // Delegate to streaming or non-streaming handler
-        if body.stream {
+        if body.stream.unwrap_or(false) {
             handle_streaming_response(
                 &self.client,
                 &self.circuit_breaker,
diff --git a/sgl-router/src/routers/openai/streaming.rs b/sgl-router/src/routers/openai/streaming.rs
index b643840d6..9a630ff82 100644
--- a/sgl-router/src/routers/openai/streaming.rs
+++ b/sgl-router/src/routers/openai/streaming.rs
@@ -572,7 +572,7 @@ pub(super) fn apply_event_transformations_inplace(
         .get_mut("response")
         .and_then(|v| v.as_object_mut())
     {
-        let desired_store = Value::Bool(original_request.store);
+        let desired_store = Value::Bool(original_request.store.unwrap_or(false));
         if response_obj.get("store") != Some(&desired_store) {
             response_obj.insert("store".to_string(), desired_store);
             changed = true;
@@ -597,8 +597,13 @@ pub(super) fn apply_event_transformations_inplace(
         if response_obj.get("tools").is_some() {
             let requested_mcp = original_request
                 .tools
-                .iter()
-                .any(|t| matches!(t.r#type, ResponseToolType::Mcp));
+                .as_ref()
+                .map(|tools| {
+                    tools
+                        .iter()
+                        .any(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                })
+                .unwrap_or(false);
 
             if requested_mcp {
                 if let Some(mcp_tools) = build_mcp_tools_value(original_request) {
@@ -658,8 +663,8 @@ pub(super) fn apply_event_transformations_inplace(
 
 /// Helper to build MCP tools value
 fn build_mcp_tools_value(original_body: &ResponsesRequest) -> Option<Value> {
-    let mcp_tool = original_body
-        .tools
+    let tools = original_body.tools.as_ref()?;
+    let mcp_tool = tools
         .iter()
         .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?;
 
@@ -1000,7 +1005,7 @@ pub(super) async fn handle_simple_streaming_passthrough(
 
     let (tx, rx) = mpsc::unbounded_channel();
 
-    let should_store = original_body.store;
+    let should_store = original_body.store.unwrap_or(false);
     let original_request = original_body.clone();
     let persist_needed = original_request.conversation.is_some();
     let previous_response_id = original_previous_response_id.clone();
@@ -1134,7 +1139,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
     prepare_mcp_payload_for_streaming(&mut payload, active_mcp);
 
     let (tx, rx) = mpsc::unbounded_channel();
-    let should_store = original_body.store;
+    let should_store = original_body.store.unwrap_or(false);
     let original_request = original_body.clone();
     let persist_needed = original_request.conversation.is_some();
     let previous_response_id = original_previous_response_id.clone();
@@ -1161,9 +1166,9 @@ pub(super) async fn handle_streaming_with_tool_interception(
 
     let server_label = original_request
         .tools
-        .iter()
-        .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-        .and_then(|t| t.server_label.as_deref())
+        .as_ref()
+        .and_then(|tools| {
+            tools
+                .iter()
+                .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                .and_then(|t| t.server_label.as_deref())
+        })
         .unwrap_or("mcp");
 
     loop {
@@ -1488,7 +1497,11 @@ pub(super) async fn handle_streaming_response(
     original_previous_response_id: Option<String>,
 ) -> Response {
     // Check if MCP is active for this request
-    let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
+    let req_mcp_manager = if let Some(ref tools) = original_body.tools {
+        mcp_manager_from_request_tools(tools.as_slice()).await
+    } else {
+        None
+    };
     let active_mcp = req_mcp_manager.as_ref().or(mcp_manager);
 
     // If no MCP is active, use simple pass-through streaming
diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs
index 4e640c1b3..c0239af46 100644
--- a/sgl-router/tests/responses_api_test.rs
+++ b/sgl-router/tests/responses_api_test.rs
@@ -89,7 +89,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
 
     // Build a simple ResponsesRequest that will trigger the tool call
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("search something".to_string()),
         instructions: Some("Be brief".to_string()),
@@ -97,15 +97,15 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
         max_tool_calls: None,
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(false),
         temperature: Some(0.2),
-        tool_choice: ToolChoice::default(),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::default()),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             authorization: None,
@@ -113,15 +113,15 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
             server_description: None,
             require_approval: None,
             allowed_tools: None,
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: None,
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_test_mcp_e2e".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: -1,
         min_p: 0.0,
@@ -338,7 +338,7 @@ async fn test_conversations_crud_basic() {
 #[test]
 fn test_responses_request_creation() {
     let request = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("Hello, world!".to_string()),
         instructions: Some("Be helpful".to_string()),
@@ -346,29 +346,29 @@ fn test_responses_request_creation() {
         max_tool_calls: None,
         metadata: None,
         model: Some("test-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
             effort: Some(ReasoningEffort::Medium),
             summary: None,
         }),
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(false),
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::WebSearchPreview,
             ..Default::default()
-        }],
-        top_logprobs: 5,
+        }]),
+        top_logprobs: Some(5),
         top_p: Some(0.9),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: Some("test-user".to_string()),
         request_id: "resp_test123".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: -1,
         min_p: 0.0,
@@ -385,7 +385,7 @@ fn test_responses_request_creation() {
 #[test]
 fn test_sampling_params_conversion() {
     let request = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("Test".to_string()),
         instructions: None,
@@ -393,23 +393,23 @@ fn test_sampling_params_conversion() {
         max_tool_calls: None,
         metadata: None,
         model: Some("test-model".to_string()),
-        parallel_tool_calls: true, // Use default true
+        parallel_tool_calls: Some(true), // Use default true
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true, // Use default true
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true), // Use default true
+        stream: Some(false),
         temperature: Some(0.8),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![],
-        top_logprobs: 0, // Use default 0
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![]),
+        top_logprobs: Some(0), // Use default 0
         top_p: Some(0.95),
-        truncation: Truncation::Auto,
+        truncation: Some(Truncation::Auto),
         user: None,
         request_id: "resp_test456".to_string(),
         priority: 0,
-        frequency_penalty: 0.1,
-        presence_penalty: 0.2,
+        frequency_penalty: Some(0.1),
+        presence_penalty: Some(0.2),
         stop: None,
         top_k: 10,
         min_p: 0.05,
@@ -493,7 +493,7 @@ fn test_reasoning_param_default() {
 #[test]
 fn test_json_serialization() {
     let request = ResponsesRequest {
-        background: true,
+        background: Some(true),
         include: None,
         input: ResponseInput::Text("Test input".to_string()),
         instructions: Some("Test instructions".to_string()),
@@ -501,29 +501,29 @@ fn test_json_serialization() {
         max_tool_calls: Some(5),
         metadata: None,
         model: Some("gpt-4".to_string()),
-        parallel_tool_calls: false,
+        parallel_tool_calls: Some(false),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
             effort: Some(ReasoningEffort::High),
             summary: None,
         }),
-        service_tier: ServiceTier::Priority,
-        store: false,
-        stream: true,
+        service_tier: Some(ServiceTier::Priority),
+        store: Some(false),
+        stream: Some(true),
         temperature: Some(0.9),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Required)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::CodeInterpreter,
             ..Default::default()
-        }],
-        top_logprobs: 10,
+        }]),
+        top_logprobs: Some(10),
         top_p: Some(0.8),
-        truncation: Truncation::Auto,
+        truncation: Some(Truncation::Auto),
         user: Some("test_user".to_string()),
         request_id: "resp_comprehensive_test".to_string(),
         priority: 1,
-        frequency_penalty: 0.3,
-        presence_penalty: 0.4,
+        frequency_penalty: Some(0.3),
+        presence_penalty: Some(0.4),
         stop: None,
         top_k: 50,
         min_p: 0.1,
@@ -537,9 +537,9 @@
 
     assert_eq!(parsed.request_id, "resp_comprehensive_test");
     assert_eq!(parsed.model, Some("gpt-4".to_string()));
-    assert!(parsed.background);
-    assert!(parsed.stream);
-    assert_eq!(parsed.tools.len(), 1);
+    assert_eq!(parsed.background, Some(true));
+    assert_eq!(parsed.stream, Some(true));
+    assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1));
 }
 
 #[tokio::test]
@@ -620,7 +620,7 @@ async fn test_multi_turn_loop_with_mcp() {
 
     // Build request with MCP tools
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("search for SGLang".to_string()),
         instructions: Some("Be helpful".to_string()),
@@ -628,30 +628,30 @@ async fn test_multi_turn_loop_with_mcp() {
         max_tool_calls: None, // No limit - test unlimited
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(false),
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             server_description: Some("Mock MCP server for testing".to_string()),
             require_approval: Some("never".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_multi_turn_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
@@ -796,7 +796,7 @@
         .expect("router");
 
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("test max calls".to_string()),
         instructions: None,
@@ -804,28 +804,28 @@ async fn test_max_tool_calls_limit() {
         max_tool_calls: Some(1), // Limit to 1 call
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
        previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: false,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(false),
+        stream: Some(false),
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_max_calls_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
@@ -990,7 +990,7 @@ async fn test_streaming_with_mcp_tool_calls() {
 
     // Build streaming request with MCP tools
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("search for something interesting".to_string()),
         instructions: Some("Use tools when needed".to_string()),
@@ -998,30 +998,30 @@ async fn test_streaming_with_mcp_tool_calls() {
         max_tool_calls: Some(3),
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: true, // KEY: Enable streaming
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(true), // KEY: Enable streaming
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             server_description: Some("Mock MCP for streaming test".to_string()),
             require_approval: Some("never".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_streaming_mcp_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
@@ -1271,7 +1271,7 @@ async fn test_streaming_multi_turn_with_mcp() {
     let (mut mcp, mut worker, router, _dir) = setup_streaming_mcp_test().await;
 
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("complex query requiring multiple tool calls".to_string()),
         instructions: Some("Be thorough".to_string()),
@@ -1279,28 +1279,28 @@ async fn test_streaming_multi_turn_with_mcp() {
         max_tool_calls: Some(5), // Allow multiple rounds
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: true,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(true),
         temperature: Some(0.8),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_streaming_multiturn_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs
index 3e288928e..b68a3f9bb 100644
--- a/sgl-router/tests/test_openai_routing.rs
+++ b/sgl-router/tests/test_openai_routing.rs
@@ -234,7 +234,7 @@ async fn test_openai_router_responses_with_mock() {
     let request1 = ResponsesRequest {
         model: Some("gpt-4o-mini".to_string()),
         input: ResponseInput::Text("Say hi".to_string()),
-        store: true,
+        store: Some(true),
         ..Default::default()
     };
 
@@ -250,7 +250,7 @@ async fn test_openai_router_responses_with_mock() {
     let request2 = ResponsesRequest {
         model: Some("gpt-4o-mini".to_string()),
         input: ResponseInput::Text("Thanks".to_string()),
-        store: true,
+        store: Some(true),
         previous_response_id: Some(resp1_id.clone()),
         ..Default::default()
     };
@@ -501,8 +501,8 @@ async fn test_openai_router_responses_streaming_with_mock() {
         instructions: Some("Be kind".to_string()),
         metadata: Some(metadata),
         previous_response_id: Some("resp_prev_chain".to_string()),
-        store: true,
-        stream: true,
+        store: Some(true),
+        stream: Some(true),
         ..Default::default()
     };
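Note on the serialization pattern (illustrative, not part of the patch): the migration above relies on serde's built-in handling of `Option` fields combined with `skip_serializing_if = "Option::is_none"`. A field the client omitted deserializes as `None` and is omitted again when the router re-serializes the request, so router-invented defaults (e.g. `store: true`) no longer leak into upstream payloads. A minimal sketch with a hypothetical two-field struct:

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct Example {
    // Omitted on the wire -> None; None -> omitted on re-serialization.
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

fn main() {
    // "stream" was not sent by the client...
    let req: Example = serde_json::from_str(r#"{"store": true}"#).unwrap();
    assert_eq!(req.store, Some(true));
    assert_eq!(req.stream, None);

    // ...and it stays absent when the payload is forwarded upstream.
    assert_eq!(serde_json::to_string(&req).unwrap(), r#"{"store":true}"#);

    // Effective defaults are applied at read time, as the patch does with
    // `store.unwrap_or(false)` and `parallel_tool_calls.unwrap_or(true)`.
    let is_stream = req.stream.unwrap_or(false);
    assert!(!is_stream);
}
```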