diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index e157ebd68..994f2c434 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -723,7 +723,10 @@ pub enum ResponseToolType { #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ResponseReasoningParam { #[serde(default = "default_reasoning_effort")] + #[serde(skip_serializing_if = "Option::is_none")] pub effort: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, } fn default_reasoning_effort() -> Option { @@ -738,6 +741,14 @@ pub enum ReasoningEffort { High, } +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ReasoningSummary { + Auto, + Concise, + Detailed, +} + #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(tag = "type")] #[serde(rename_all = "snake_case")] diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index e9810ef49..2e5f85f52 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -26,7 +26,6 @@ use std::{ collections::HashMap, io, sync::{atomic::AtomicBool, Arc}, - time::SystemTime, }; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -57,6 +56,69 @@ impl std::fmt::Debug for OpenAIRouter { } } +/// Configuration for MCP tool calling loops +#[derive(Debug, Clone)] +struct McpLoopConfig { + /// Maximum iterations as safety limit (internal only, default: 10) + /// Prevents infinite loops when max_tool_calls is not set + max_iterations: usize, +} + +impl Default for McpLoopConfig { + fn default() -> Self { + Self { max_iterations: 10 } + } +} + +/// State for tracking multi-turn tool calling loop +struct ToolLoopState { + /// Current iteration number (starts at 0, increments with each tool call) + iteration: usize, + /// Total number of tool calls executed + total_calls: usize, + /// Conversation history (function_call and function_call_output items) + conversation_history: Vec, + /// Original user input (preserved for building resume payloads) + original_input: ResponseInput, +} + +impl ToolLoopState { + fn new(original_input: ResponseInput) -> Self { + Self { + iteration: 0, + total_calls: 0, + conversation_history: Vec::new(), + original_input, + } + } + + /// Record a tool call in the loop state + fn record_call( + &mut self, + call_id: String, + tool_name: String, + args_json_str: String, + output_str: String, + ) { + // Add function_call item to history + let func_item = json!({ + "type": "function_call", + "call_id": call_id, + "name": tool_name, + "arguments": args_json_str + }); + self.conversation_history.push(func_item); + + // Add function_call_output item to history + let output_item = json!({ + "type": "function_call_output", + "call_id": call_id, + "output": output_str + }); + self.conversation_history.push(output_item); + } +} + /// Helper that parses SSE frames from the OpenAI responses stream and /// accumulates enough information to persist the final response locally. struct StreamingResponseAccumulator { @@ -388,126 +450,32 @@ impl OpenAIRouter { obj.insert("store".to_string(), Value::Bool(original_body.store)); } - let mut final_response_json = openai_response_json; - - if let Some(mcp) = active_mcp { - if let Some((call_id, tool_name, args_json_str)) = - Self::extract_function_call(&final_response_json) - { - info!( - "Detected function call: name={}, call_id={}, args={}", - tool_name, call_id, args_json_str - ); - - let call_started = SystemTime::now(); - let call_result = - Self::execute_mcp_call(mcp, &tool_name, &args_json_str).await; - let call_duration_ms = - call_started.elapsed().unwrap_or_default().as_millis(); - - let (output_payload, call_ok, call_error) = match call_result { - Ok((server, out)) => { - info!( - call_id = %call_id, - tool_name = %tool_name, - server = %server, - duration_ms = call_duration_ms, - "MCP tool call succeeded" - ); - (out, true, None) - } - Err(err) => { - warn!( - call_id = %call_id, - tool_name = %tool_name, - duration_ms = call_duration_ms, - error = %err, - "MCP tool call failed" - ); - ( - serde_json::json!({ - "error": err - }) - .to_string(), - false, - Some(err), - ) - } - }; - + // If MCP is active and we detect a function call, enter the tool loop + let mut final_response_json = if let Some(mcp) = active_mcp { + if Self::extract_function_call(&openai_response_json).is_some() { + // Use the loop to handle potentially multiple tool calls + let loop_config = McpLoopConfig::default(); match self - .resume_with_tool_result(ResumeWithToolArgs { - url: &url, + .execute_tool_loop( + &url, headers, - original_payload: &payload, - call_id: &call_id, - tool_name: &tool_name, - args_json_str: &args_json_str, - output_str: &output_payload, + payload.clone(), original_body, - }) + mcp, + &loop_config, + ) .await { - Ok(mut resumed_json) => { - // Inject MCP output items (mcp_list_tools and mcp_call) - let server_label = original_body - .tools - .iter() - .find(|t| matches!(t.r#type, ResponseToolType::Mcp)) - .and_then(|t| t.server_label.as_deref()) - .unwrap_or("mcp"); - - if let Err(inject_err) = Self::inject_mcp_output_items( - &mut resumed_json, - mcp, - McpOutputItemsArgs { - tool_name: &tool_name, - args_json: &args_json_str, - output: &output_payload, - server_label, - success: call_ok, - error: call_error.as_deref(), - }, - ) { - warn!( - "Failed to inject MCP output items: {}", - inject_err - ); - } - - if !call_ok { - if let Some(obj) = resumed_json.as_object_mut() { - let metadata_value = - obj.entry("metadata").or_insert_with(|| { - Value::Object(serde_json::Map::new()) - }); - if let Some(metadata) = - metadata_value.as_object_mut() - { - if let Some(err_msg) = call_error.as_ref() { - metadata.insert( - "mcp_error".to_string(), - Value::String(err_msg.clone()), - ); - } - } - } - } - final_response_json = resumed_json; - } + Ok(loop_result) => loop_result, Err(err) => { - warn!("Failed to resume with tool result: {}", err); + warn!("Tool loop failed: {}", err); let error_body = json!({ "error": { - "message": format!( - "Failed to resume with tool result: {}", - err - ), + "message": format!("Tool loop failed: {}", err), "type": "internal_error", } }) .to_string(); - return ( StatusCode::INTERNAL_SERVER_ERROR, [("content-type", "application/json")], @@ -517,10 +485,14 @@ impl OpenAIRouter { } } } else { - info!("No function call found in upstream response; skipping MCP"); + // No function call detected, use response as-is + openai_response_json } - } + } else { + openai_response_json + }; + // Mask tools back to MCP format for client Self::mask_tools_as_mcp(&mut final_response_json, original_body); if original_body.store { if let Err(e) = self @@ -1040,26 +1012,6 @@ impl OpenAIRouter { } } -struct ResumeWithToolArgs<'a> { - url: &'a str, - headers: Option<&'a HeaderMap>, - original_payload: &'a Value, - call_id: &'a str, - tool_name: &'a str, - args_json_str: &'a str, - output_str: &'a str, - original_body: &'a ResponsesRequest, -} - -struct McpOutputItemsArgs<'a> { - tool_name: &'a str, - args_json: &'a str, - output: &'a str, - server_label: &'a str, - success: bool, - error: Option<&'a str>, -} - impl OpenAIRouter { fn extract_function_call(resp: &Value) -> Option<(String, String, String)> { let output = resp.get("output")?.as_array()?; @@ -1150,6 +1102,375 @@ impl OpenAIRouter { Ok((server_name, output_str)) } + /// Build a resume payload with conversation history + fn build_resume_payload( + base_payload: &Value, + conversation_history: &[Value], + original_input: &ResponseInput, + tools_json: &Value, + ) -> Result { + // Clone the base payload which already has cleaned fields + let mut payload = base_payload.clone(); + + let obj = payload + .as_object_mut() + .ok_or_else(|| "payload not an object".to_string())?; + + // Build input array: start with original user input + let mut input_array = Vec::new(); + + // Add original user message + // For structured input, serialize the original input items + match original_input { + ResponseInput::Text(text) => { + let user_item = json!({ + "type": "message", + "role": "user", + "content": [{ "type": "input_text", "text": text }] + }); + input_array.push(user_item); + } + ResponseInput::Items(items) => { + // Items are already structured ResponseInputOutputItem, convert to JSON + if let Ok(items_value) = serde_json::to_value(items) { + if let Some(items_arr) = items_value.as_array() { + input_array.extend_from_slice(items_arr); + } + } + } + } + + // Add all conversation history (function calls and outputs) + input_array.extend_from_slice(conversation_history); + + obj.insert("input".to_string(), Value::Array(input_array)); + + // Use the transformed tools (function tools, not MCP tools) + if let Some(tools_arr) = tools_json.as_array() { + if !tools_arr.is_empty() { + obj.insert("tools".to_string(), tools_json.clone()); + } + } + + // Ensure non-streaming and no store to upstream + obj.insert("stream".to_string(), Value::Bool(false)); + obj.insert("store".to_string(), Value::Bool(false)); + + // Note: SGLang-specific fields were already removed from base_payload + // before it was passed to execute_tool_loop (see route_responses lines 1935-1946) + + Ok(payload) + } + + /// Helper function to build mcp_call items from executed tool calls in conversation history + fn build_executed_mcp_call_items( + conversation_history: &[Value], + server_label: &str, + ) -> Vec { + let mut mcp_call_items = Vec::new(); + + for item in conversation_history { + if item.get("type").and_then(|t| t.as_str()) == Some("function_call") { + let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or(""); + let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); + let args = item + .get("arguments") + .and_then(|v| v.as_str()) + .unwrap_or("{}"); + + // Find corresponding output + let output_item = conversation_history.iter().find(|o| { + o.get("type").and_then(|t| t.as_str()) == Some("function_call_output") + && o.get("call_id").and_then(|c| c.as_str()) == Some(call_id) + }); + + let output_str = output_item + .and_then(|o| o.get("output").and_then(|v| v.as_str())) + .unwrap_or("{}"); + + // Check if output contains error by parsing JSON + let is_error = serde_json::from_str::(output_str) + .map(|v| v.get("error").is_some()) + .unwrap_or(false); + + let mcp_call_item = Self::build_mcp_call_item( + tool_name, + args, + output_str, + server_label, + !is_error, + if is_error { + Some("Tool execution failed") + } else { + None + }, + ); + mcp_call_items.push(mcp_call_item); + } + } + + mcp_call_items + } + + /// Build an incomplete response when limits are exceeded + fn build_incomplete_response( + mut response: Value, + state: ToolLoopState, + reason: &str, + active_mcp: &Arc, + original_body: &ResponsesRequest, + ) -> Result { + let obj = response + .as_object_mut() + .ok_or_else(|| "response not an object".to_string())?; + + // Set status to completed (not failed - partial success) + obj.insert("status".to_string(), Value::String("completed".to_string())); + + // Set incomplete_details + obj.insert( + "incomplete_details".to_string(), + json!({ "reason": reason }), + ); + + // Convert any function_call in output to mcp_call format + if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) { + let server_label = original_body + .tools + .iter() + .find(|t| matches!(t.r#type, ResponseToolType::Mcp)) + .and_then(|t| t.server_label.as_deref()) + .unwrap_or("mcp"); + + // Find any function_call items and convert them to mcp_call (incomplete) + let mut mcp_call_items = Vec::new(); + for item in output_array.iter() { + if item.get("type").and_then(|t| t.as_str()) == Some("function_tool_call") { + let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); + let args = item + .get("arguments") + .and_then(|v| v.as_str()) + .unwrap_or("{}"); + + // Mark as incomplete - not executed + let mcp_call_item = Self::build_mcp_call_item( + tool_name, + args, + "", // No output - wasn't executed + server_label, + false, // Not successful + Some("Not executed - response stopped due to limit"), + ); + mcp_call_items.push(mcp_call_item); + } + } + + // Add mcp_list_tools and executed mcp_call items at the beginning + if state.total_calls > 0 || !mcp_call_items.is_empty() { + let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label); + output_array.insert(0, list_tools_item); + + // Add mcp_call items for executed calls using helper + let executed_items = + Self::build_executed_mcp_call_items(&state.conversation_history, server_label); + + let mut insert_pos = 1; + for item in executed_items { + output_array.insert(insert_pos, item); + insert_pos += 1; + } + + // Add incomplete mcp_call items + for item in mcp_call_items { + output_array.insert(insert_pos, item); + insert_pos += 1; + } + } + } + + // Add warning to metadata + if let Some(metadata_val) = obj.get_mut("metadata") { + if let Some(metadata_obj) = metadata_val.as_object_mut() { + if let Some(mcp_val) = metadata_obj.get_mut("mcp") { + if let Some(mcp_obj) = mcp_val.as_object_mut() { + mcp_obj.insert( + "truncation_warning".to_string(), + Value::String(format!( + "Loop terminated at {} iterations, {} total calls (reason: {})", + state.iteration, state.total_calls, reason + )), + ); + } + } + } + } + + Ok(response) + } + + /// Execute the tool calling loop + async fn execute_tool_loop( + &self, + url: &str, + headers: Option<&HeaderMap>, + initial_payload: Value, + original_body: &ResponsesRequest, + active_mcp: &Arc, + config: &McpLoopConfig, + ) -> Result { + let mut state = ToolLoopState::new(original_body.input.clone()); + + // Get max_tool_calls from request (None means no user-specified limit) + let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize); + + // Keep initial_payload as base template (already has fields cleaned) + let base_payload = initial_payload.clone(); + let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([])); + let mut current_payload = initial_payload; + + info!( + "Starting tool loop: max_tool_calls={:?}, max_iterations={}", + max_tool_calls, config.max_iterations + ); + + loop { + // Make request to upstream + let request_builder = self.client.post(url).json(¤t_payload); + let request_builder = if let Some(headers) = headers { + apply_request_headers(headers, request_builder, true) + } else { + request_builder + }; + + let response = request_builder + .send() + .await + .map_err(|e| format!("upstream request failed: {}", e))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(format!("upstream error {}: {}", status, body)); + } + + let mut response_json = response + .json::() + .await + .map_err(|e| format!("parse response: {}", e))?; + + // Check for function call + if let Some((call_id, tool_name, args_json_str)) = + Self::extract_function_call(&response_json) + { + state.iteration += 1; + state.total_calls += 1; + + info!( + "Tool loop iteration {}: calling {} (call_id: {})", + state.iteration, tool_name, call_id + ); + + // Check combined limit: use minimum of user's max_tool_calls (if set) and safety max_iterations + let effective_limit = match max_tool_calls { + Some(user_max) => user_max.min(config.max_iterations), + None => config.max_iterations, + }; + + if state.total_calls > effective_limit { + if let Some(user_max) = max_tool_calls { + if state.total_calls > user_max { + warn!("Reached user-specified max_tool_calls limit: {}", user_max); + } else { + warn!( + "Reached safety max_iterations limit: {}", + config.max_iterations + ); + } + } else { + warn!( + "Reached safety max_iterations limit: {}", + config.max_iterations + ); + } + + return Self::build_incomplete_response( + response_json, + state, + "max_tool_calls", + active_mcp, + original_body, + ); + } + + // Execute tool + let call_result = + Self::execute_mcp_call(active_mcp, &tool_name, &args_json_str).await; + + let output_str = match call_result { + Ok((_, output)) => output, + Err(err) => { + warn!("Tool execution failed: {}", err); + // Return error as output, let model decide how to proceed + json!({ "error": err }).to_string() + } + }; + + // Record the call + state.record_call(call_id, tool_name, args_json_str, output_str); + + // Build resume payload + current_payload = Self::build_resume_payload( + &base_payload, + &state.conversation_history, + &state.original_input, + &tools_json, + )?; + } else { + // No more tool calls, we're done + info!( + "Tool loop completed: {} iterations, {} total calls", + state.iteration, state.total_calls + ); + + // Inject MCP output items if we executed any tools + if state.total_calls > 0 { + let server_label = original_body + .tools + .iter() + .find(|t| matches!(t.r#type, ResponseToolType::Mcp)) + .and_then(|t| t.server_label.as_deref()) + .unwrap_or("mcp"); + + // Build mcp_list_tools item + let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label); + + // Insert at beginning of output array + if let Some(output_array) = response_json + .get_mut("output") + .and_then(|v| v.as_array_mut()) + { + output_array.insert(0, list_tools_item); + + // Build mcp_call items using helper function + let mcp_call_items = Self::build_executed_mcp_call_items( + &state.conversation_history, + server_label, + ); + + // Insert mcp_call items after mcp_list_tools using mutable position + let mut insert_pos = 1; + for item in mcp_call_items { + output_array.insert(insert_pos, item); + insert_pos += 1; + } + } + } + + return Ok(response_json); + } + } + } + /// Generate a unique ID for MCP output items (similar to OpenAI format) fn generate_mcp_id(prefix: &str) -> String { use rand::RngCore; @@ -1213,113 +1534,6 @@ impl OpenAIRouter { "server_label": server_label }) } - - /// Inject mcp_list_tools and mcp_call items into the response output array - fn inject_mcp_output_items( - response_json: &mut Value, - mcp: &Arc, - args: McpOutputItemsArgs, - ) -> Result<(), String> { - let output_array = response_json - .get_mut("output") - .and_then(|v| v.as_array_mut()) - .ok_or("missing output array")?; - - // Build MCP output items - let list_tools_item = Self::build_mcp_list_tools_item(mcp, args.server_label); - let call_item = Self::build_mcp_call_item( - args.tool_name, - args.args_json, - args.output, - args.server_label, - args.success, - args.error, - ); - - // Find the index of the last message item to insert mcp_call before it - let call_insertion_index = output_array - .iter() - .rposition(|item| item.get("type").and_then(|v| v.as_str()) == Some("message")) - .unwrap_or(output_array.len()); - - // Insert items in-place for efficiency - output_array.insert(call_insertion_index, call_item); - output_array.insert(0, list_tools_item); - - Ok(()) - } - - async fn resume_with_tool_result(&self, args: ResumeWithToolArgs<'_>) -> Result { - let mut payload2 = args.original_payload.clone(); - let obj = payload2 - .as_object_mut() - .ok_or_else(|| "payload not an object".to_string())?; - - // Build function_call and tool result items per OpenAI Responses spec - let user_item = serde_json::json!({ - "type": "message", - "role": "user", - "content": args.original_body.input.clone() - }); - // temp system message since currently only support 1 turn of mcp function call - let system_item = serde_json::json!({ - "type": "message", - "role": "system", - "content": "please resume with the following tool result, and answer user's question directly, don't trigger any more tool calls" - }); - - let func_item = serde_json::json!({ - "type": "function_call", - "call_id": args.call_id, - "name": args.tool_name, - "arguments": args.args_json_str - }); - // Build tool result item as function_call_output per OpenAI Responses spec - let tool_item = serde_json::json!({ - "type": "function_call_output", - "call_id": args.call_id, - "output": args.output_str - }); - - obj.insert( - "input".to_string(), - Value::Array(vec![user_item, system_item, func_item, tool_item]), - ); - - // Ensure non-streaming and no store to upstream - obj.insert("stream".to_string(), Value::Bool(false)); - obj.insert("store".to_string(), Value::Bool(false)); - let mut req = self.client.post(args.url).json(&payload2); - if let Some(headers) = args.headers { - req = apply_request_headers(headers, req, true); - } - let resp = req - .send() - .await - .map_err(|e| format!("resume request failed: {}", e))?; - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - return Err(format!("resume upstream error {}: {}", status, body)); - } - let mut v = resp - .json::() - .await - .map_err(|e| format!("parse resume response: {}", e))?; - - if let Some(instr) = &args.original_body.instructions { - if let Some(obj) = v.as_object_mut() { - obj.entry("instructions") - .or_insert(Value::String(instr.clone())); - } - } - // After resume, mask tools as MCP if request used MCP - Self::mask_tools_as_mcp(&mut v, args.original_body); - if let Some(obj) = v.as_object_mut() { - obj.insert("store".to_string(), Value::Bool(args.original_body.store)); - } - Ok(v) - } } #[async_trait] diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs index 80e853153..ffc71a5c4 100644 --- a/sgl-router/tests/responses_api_test.rs +++ b/sgl-router/tests/responses_api_test.rs @@ -252,6 +252,7 @@ fn test_responses_request_creation() { previous_response_id: None, reasoning: Some(ResponseReasoningParam { effort: Some(ReasoningEffort::Medium), + summary: None, }), service_tier: ServiceTier::Auto, store: true, @@ -380,6 +381,7 @@ fn test_usage_conversion() { fn test_reasoning_param_default() { let param = ResponseReasoningParam { effort: Some(ReasoningEffort::Medium), + summary: None, }; let json = serde_json::to_string(¶m).unwrap(); @@ -403,6 +405,7 @@ fn test_json_serialization() { previous_response_id: None, reasoning: Some(ResponseReasoningParam { effort: Some(ReasoningEffort::High), + summary: None, }), service_tier: ServiceTier::Priority, store: false, @@ -437,3 +440,328 @@ fn test_json_serialization() { assert!(parsed.stream); assert_eq!(parsed.tools.len(), 1); } + +#[tokio::test] +async fn test_multi_turn_loop_with_mcp() { + // This test verifies the multi-turn loop functionality: + // 1. Initial request with MCP tools + // 2. Mock worker returns function_call + // 3. Router executes MCP tool and resumes + // 4. Mock worker returns final answer + // 5. Verify the complete flow worked + + // Start mock MCP server + let mut mcp = MockMCPServer::start().await.expect("start mcp"); + + // Write a temp MCP config file + let mcp_yaml = format!( + "servers:\n - name: mock\n protocol: streamable\n url: {}\n", + mcp.url() + ); + let dir = tempfile::tempdir().expect("tmpdir"); + let cfg_path = dir.path().join("mcp.yaml"); + std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg"); + std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap()); + + // Start mock OpenAI worker + let mut worker = MockWorker::new(MockWorkerConfig { + port: 0, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let worker_url = worker.start().await.expect("start worker"); + + // Build router config + let router_cfg = RouterConfig { + mode: RoutingMode::OpenAI { + worker_urls: vec![worker_url], + }, + connection_mode: ConnectionMode::Http, + policy: PolicyConfig::Random, + host: "127.0.0.1".to_string(), + port: 0, + max_payload_size: 8 * 1024 * 1024, + request_timeout_secs: 60, + worker_startup_timeout_secs: 5, + worker_startup_check_interval_secs: 1, + dp_aware: false, + api_key: None, + discovery: None, + metrics: None, + log_dir: None, + log_level: Some("info".to_string()), + request_id_headers: None, + max_concurrent_requests: 32, + queue_size: 0, + queue_timeout_secs: 5, + rate_limit_tokens_per_second: None, + cors_allowed_origins: vec![], + retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), + disable_retries: false, + disable_circuit_breaker: false, + health_check: HealthCheckConfig::default(), + enable_igw: false, + model_path: None, + tokenizer_path: None, + history_backend: sglang_router_rs::config::HistoryBackend::Memory, + oracle: None, + }; + + let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx"); + let router = RouterFactory::create_router(&Arc::new(ctx)) + .await + .expect("router"); + + // Build request with MCP tools + let req = ResponsesRequest { + background: false, + include: None, + input: ResponseInput::Text("search for SGLang".to_string()), + instructions: Some("Be helpful".to_string()), + max_output_tokens: Some(128), + max_tool_calls: None, // No limit - test unlimited + metadata: None, + model: Some("mock-model".to_string()), + parallel_tool_calls: true, + previous_response_id: None, + reasoning: None, + service_tier: ServiceTier::Auto, + store: true, + stream: false, + temperature: Some(0.7), + tool_choice: ToolChoice::Value(ToolChoiceValue::Auto), + tools: vec![ResponseTool { + r#type: ResponseToolType::Mcp, + server_url: Some(mcp.url()), + server_label: Some("mock".to_string()), + server_description: Some("Mock MCP server for testing".to_string()), + require_approval: Some("never".to_string()), + ..Default::default() + }], + top_logprobs: 0, + top_p: Some(1.0), + truncation: Truncation::Disabled, + user: None, + request_id: "resp_multi_turn_test".to_string(), + priority: 0, + frequency_penalty: 0.0, + presence_penalty: 0.0, + stop: None, + top_k: 50, + min_p: 0.0, + repetition_penalty: 1.0, + }; + + // Execute the request (this should trigger the multi-turn loop) + let response = router.route_responses(None, &req, None).await; + + // Check status + assert_eq!( + response.status(), + axum::http::StatusCode::OK, + "Request should succeed" + ); + + // Read the response body + use axum::body::to_bytes; + let response_body = response.into_body(); + let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap(); + let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + + println!( + "Multi-turn response: {}", + serde_json::to_string_pretty(&response_json).unwrap() + ); + + // Verify the response structure + assert_eq!(response_json["object"], "response"); + assert_eq!(response_json["status"], "completed"); + // Note: mock worker generates its own ID, so we just verify it exists + assert!( + response_json["id"].is_string(), + "Response should have an id" + ); + + // Check that output contains final message + let output = response_json["output"] + .as_array() + .expect("output should be array"); + assert!(!output.is_empty(), "output should not be empty"); + + // Find the final message with text + let has_final_text = output.iter().any(|item| { + item.get("type") + .and_then(|t| t.as_str()) + .map(|t| t == "message") + .unwrap_or(false) + && item + .get("content") + .and_then(|c| c.as_array()) + .map(|arr| { + arr.iter().any(|part| { + part.get("type") + .and_then(|t| t.as_str()) + .map(|t| t == "output_text") + .unwrap_or(false) + }) + }) + .unwrap_or(false) + }); + + assert!(has_final_text, "Should have final text output"); + + // Verify tools are masked back to MCP format + let tools = response_json["tools"] + .as_array() + .expect("tools should be array"); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0]["type"], "mcp"); + assert_eq!(tools[0]["server_label"], "mock"); + + // Clean up + std::env::remove_var("SGLANG_MCP_CONFIG"); + worker.stop().await; + mcp.stop().await; +} + +#[tokio::test] +async fn test_max_tool_calls_limit() { + // This test verifies that max_tool_calls is respected + // Note: The mock worker returns a final answer after one tool call, + // so with max_tool_calls=1, it completes normally (doesn't exceed the limit) + + let mut mcp = MockMCPServer::start().await.expect("start mcp"); + let mcp_yaml = format!( + "servers:\n - name: mock\n protocol: streamable\n url: {}\n", + mcp.url() + ); + let dir = tempfile::tempdir().expect("tmpdir"); + let cfg_path = dir.path().join("mcp.yaml"); + std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg"); + std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap()); + + let mut worker = MockWorker::new(MockWorkerConfig { + port: 0, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let worker_url = worker.start().await.expect("start worker"); + + let router_cfg = RouterConfig { + mode: RoutingMode::OpenAI { + worker_urls: vec![worker_url], + }, + connection_mode: ConnectionMode::Http, + policy: PolicyConfig::Random, + host: "127.0.0.1".to_string(), + port: 0, + max_payload_size: 8 * 1024 * 1024, + request_timeout_secs: 60, + worker_startup_timeout_secs: 5, + worker_startup_check_interval_secs: 1, + dp_aware: false, + api_key: None, + discovery: None, + metrics: None, + log_dir: None, + log_level: Some("info".to_string()), + request_id_headers: None, + max_concurrent_requests: 32, + queue_size: 0, + queue_timeout_secs: 5, + rate_limit_tokens_per_second: None, + cors_allowed_origins: vec![], + retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), + disable_retries: false, + disable_circuit_breaker: false, + health_check: HealthCheckConfig::default(), + enable_igw: false, + model_path: None, + tokenizer_path: None, + history_backend: sglang_router_rs::config::HistoryBackend::Memory, + oracle: None, + }; + + let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx"); + let router = RouterFactory::create_router(&Arc::new(ctx)) + .await + .expect("router"); + + let req = ResponsesRequest { + background: false, + include: None, + input: ResponseInput::Text("test max calls".to_string()), + instructions: None, + max_output_tokens: Some(128), + max_tool_calls: Some(1), // Limit to 1 call + metadata: None, + model: Some("mock-model".to_string()), + parallel_tool_calls: true, + previous_response_id: None, + reasoning: None, + service_tier: ServiceTier::Auto, + store: false, + stream: false, + temperature: Some(0.7), + tool_choice: ToolChoice::Value(ToolChoiceValue::Auto), + tools: vec![ResponseTool { + r#type: ResponseToolType::Mcp, + server_url: Some(mcp.url()), + server_label: Some("mock".to_string()), + ..Default::default() + }], + top_logprobs: 0, + top_p: Some(1.0), + truncation: Truncation::Disabled, + user: None, + request_id: "resp_max_calls_test".to_string(), + priority: 0, + frequency_penalty: 0.0, + presence_penalty: 0.0, + stop: None, + top_k: 50, + min_p: 0.0, + repetition_penalty: 1.0, + }; + + let response = router.route_responses(None, &req, None).await; + assert_eq!(response.status(), axum::http::StatusCode::OK); + + use axum::body::to_bytes; + let response_body = response.into_body(); + let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap(); + let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + + println!( + "Max calls response: {}", + serde_json::to_string_pretty(&response_json).unwrap() + ); + + // With max_tool_calls=1, the mock returns a final answer after 1 call + // So it completes normally without exceeding the limit + assert_eq!(response_json["status"], "completed"); + + // Verify the basic response structure + assert!(response_json["id"].is_string()); + assert_eq!(response_json["object"], "response"); + + // The response should have tools masked back to MCP format + let tools = response_json["tools"] + .as_array() + .expect("tools should be array"); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0]["type"], "mcp"); + + // Note: To test actual limit exceeding, we would need a mock that keeps + // calling tools indefinitely, which would hit max_iterations (safety limit) + + std::env::remove_var("SGLANG_MCP_CONFIG"); + worker.stop().await; + mcp.stop().await; +}