[router] Add multi-turn tool calling loop support for MCP integration (#11143)

This commit is contained in:
Keyang Ru
2025-10-01 12:50:21 -07:00
committed by GitHub
parent 96fe2d0f15
commit a28b394fba
3 changed files with 791 additions and 238 deletions

View File

@@ -26,7 +26,6 @@ use std::{
collections::HashMap,
io,
sync::{atomic::AtomicBool, Arc},
time::SystemTime,
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -57,6 +56,69 @@ impl std::fmt::Debug for OpenAIRouter {
}
}
/// Configuration for MCP tool calling loops
#[derive(Debug, Clone)]
struct McpLoopConfig {
/// Maximum iterations as safety limit (internal only, default: 10)
/// Prevents infinite loops when max_tool_calls is not set
max_iterations: usize,
}
impl Default for McpLoopConfig {
fn default() -> Self {
Self { max_iterations: 10 }
}
}
/// State for tracking multi-turn tool calling loop
struct ToolLoopState {
/// Current iteration number (starts at 0, increments with each tool call)
iteration: usize,
/// Total number of tool calls executed
total_calls: usize,
/// Conversation history (function_call and function_call_output items)
conversation_history: Vec<Value>,
/// Original user input (preserved for building resume payloads)
original_input: ResponseInput,
}
impl ToolLoopState {
fn new(original_input: ResponseInput) -> Self {
Self {
iteration: 0,
total_calls: 0,
conversation_history: Vec::new(),
original_input,
}
}
/// Record a tool call in the loop state
fn record_call(
&mut self,
call_id: String,
tool_name: String,
args_json_str: String,
output_str: String,
) {
// Add function_call item to history
let func_item = json!({
"type": "function_call",
"call_id": call_id,
"name": tool_name,
"arguments": args_json_str
});
self.conversation_history.push(func_item);
// Add function_call_output item to history
let output_item = json!({
"type": "function_call_output",
"call_id": call_id,
"output": output_str
});
self.conversation_history.push(output_item);
}
}
/// Helper that parses SSE frames from the OpenAI responses stream and
/// accumulates enough information to persist the final response locally.
struct StreamingResponseAccumulator {
@@ -388,126 +450,32 @@ impl OpenAIRouter {
obj.insert("store".to_string(), Value::Bool(original_body.store));
}
let mut final_response_json = openai_response_json;
if let Some(mcp) = active_mcp {
if let Some((call_id, tool_name, args_json_str)) =
Self::extract_function_call(&final_response_json)
{
info!(
"Detected function call: name={}, call_id={}, args={}",
tool_name, call_id, args_json_str
);
let call_started = SystemTime::now();
let call_result =
Self::execute_mcp_call(mcp, &tool_name, &args_json_str).await;
let call_duration_ms =
call_started.elapsed().unwrap_or_default().as_millis();
let (output_payload, call_ok, call_error) = match call_result {
Ok((server, out)) => {
info!(
call_id = %call_id,
tool_name = %tool_name,
server = %server,
duration_ms = call_duration_ms,
"MCP tool call succeeded"
);
(out, true, None)
}
Err(err) => {
warn!(
call_id = %call_id,
tool_name = %tool_name,
duration_ms = call_duration_ms,
error = %err,
"MCP tool call failed"
);
(
serde_json::json!({
"error": err
})
.to_string(),
false,
Some(err),
)
}
};
// If MCP is active and we detect a function call, enter the tool loop
let mut final_response_json = if let Some(mcp) = active_mcp {
if Self::extract_function_call(&openai_response_json).is_some() {
// Use the loop to handle potentially multiple tool calls
let loop_config = McpLoopConfig::default();
match self
.resume_with_tool_result(ResumeWithToolArgs {
url: &url,
.execute_tool_loop(
&url,
headers,
original_payload: &payload,
call_id: &call_id,
tool_name: &tool_name,
args_json_str: &args_json_str,
output_str: &output_payload,
payload.clone(),
original_body,
})
mcp,
&loop_config,
)
.await
{
Ok(mut resumed_json) => {
// Inject MCP output items (mcp_list_tools and mcp_call)
let server_label = original_body
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
.unwrap_or("mcp");
if let Err(inject_err) = Self::inject_mcp_output_items(
&mut resumed_json,
mcp,
McpOutputItemsArgs {
tool_name: &tool_name,
args_json: &args_json_str,
output: &output_payload,
server_label,
success: call_ok,
error: call_error.as_deref(),
},
) {
warn!(
"Failed to inject MCP output items: {}",
inject_err
);
}
if !call_ok {
if let Some(obj) = resumed_json.as_object_mut() {
let metadata_value =
obj.entry("metadata").or_insert_with(|| {
Value::Object(serde_json::Map::new())
});
if let Some(metadata) =
metadata_value.as_object_mut()
{
if let Some(err_msg) = call_error.as_ref() {
metadata.insert(
"mcp_error".to_string(),
Value::String(err_msg.clone()),
);
}
}
}
}
final_response_json = resumed_json;
}
Ok(loop_result) => loop_result,
Err(err) => {
warn!("Failed to resume with tool result: {}", err);
warn!("Tool loop failed: {}", err);
let error_body = json!({
"error": {
"message": format!(
"Failed to resume with tool result: {}",
err
),
"message": format!("Tool loop failed: {}", err),
"type": "internal_error",
}
})
.to_string();
return (
StatusCode::INTERNAL_SERVER_ERROR,
[("content-type", "application/json")],
@@ -517,10 +485,14 @@ impl OpenAIRouter {
}
}
} else {
info!("No function call found in upstream response; skipping MCP");
// No function call detected, use response as-is
openai_response_json
}
}
} else {
openai_response_json
};
// Mask tools back to MCP format for client
Self::mask_tools_as_mcp(&mut final_response_json, original_body);
if original_body.store {
if let Err(e) = self
@@ -1040,26 +1012,6 @@ impl OpenAIRouter {
}
}
struct ResumeWithToolArgs<'a> {
url: &'a str,
headers: Option<&'a HeaderMap>,
original_payload: &'a Value,
call_id: &'a str,
tool_name: &'a str,
args_json_str: &'a str,
output_str: &'a str,
original_body: &'a ResponsesRequest,
}
struct McpOutputItemsArgs<'a> {
tool_name: &'a str,
args_json: &'a str,
output: &'a str,
server_label: &'a str,
success: bool,
error: Option<&'a str>,
}
impl OpenAIRouter {
fn extract_function_call(resp: &Value) -> Option<(String, String, String)> {
let output = resp.get("output")?.as_array()?;
@@ -1150,6 +1102,375 @@ impl OpenAIRouter {
Ok((server_name, output_str))
}
/// Build a resume payload with conversation history
fn build_resume_payload(
base_payload: &Value,
conversation_history: &[Value],
original_input: &ResponseInput,
tools_json: &Value,
) -> Result<Value, String> {
// Clone the base payload which already has cleaned fields
let mut payload = base_payload.clone();
let obj = payload
.as_object_mut()
.ok_or_else(|| "payload not an object".to_string())?;
// Build input array: start with original user input
let mut input_array = Vec::new();
// Add original user message
// For structured input, serialize the original input items
match original_input {
ResponseInput::Text(text) => {
let user_item = json!({
"type": "message",
"role": "user",
"content": [{ "type": "input_text", "text": text }]
});
input_array.push(user_item);
}
ResponseInput::Items(items) => {
// Items are already structured ResponseInputOutputItem, convert to JSON
if let Ok(items_value) = serde_json::to_value(items) {
if let Some(items_arr) = items_value.as_array() {
input_array.extend_from_slice(items_arr);
}
}
}
}
// Add all conversation history (function calls and outputs)
input_array.extend_from_slice(conversation_history);
obj.insert("input".to_string(), Value::Array(input_array));
// Use the transformed tools (function tools, not MCP tools)
if let Some(tools_arr) = tools_json.as_array() {
if !tools_arr.is_empty() {
obj.insert("tools".to_string(), tools_json.clone());
}
}
// Ensure non-streaming and no store to upstream
obj.insert("stream".to_string(), Value::Bool(false));
obj.insert("store".to_string(), Value::Bool(false));
// Note: SGLang-specific fields were already removed from base_payload
// before it was passed to execute_tool_loop (see route_responses lines 1935-1946)
Ok(payload)
}
/// Helper function to build mcp_call items from executed tool calls in conversation history
fn build_executed_mcp_call_items(
conversation_history: &[Value],
server_label: &str,
) -> Vec<Value> {
let mut mcp_call_items = Vec::new();
for item in conversation_history {
if item.get("type").and_then(|t| t.as_str()) == Some("function_call") {
let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or("");
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
let args = item
.get("arguments")
.and_then(|v| v.as_str())
.unwrap_or("{}");
// Find corresponding output
let output_item = conversation_history.iter().find(|o| {
o.get("type").and_then(|t| t.as_str()) == Some("function_call_output")
&& o.get("call_id").and_then(|c| c.as_str()) == Some(call_id)
});
let output_str = output_item
.and_then(|o| o.get("output").and_then(|v| v.as_str()))
.unwrap_or("{}");
// Check if output contains error by parsing JSON
let is_error = serde_json::from_str::<serde_json::Value>(output_str)
.map(|v| v.get("error").is_some())
.unwrap_or(false);
let mcp_call_item = Self::build_mcp_call_item(
tool_name,
args,
output_str,
server_label,
!is_error,
if is_error {
Some("Tool execution failed")
} else {
None
},
);
mcp_call_items.push(mcp_call_item);
}
}
mcp_call_items
}
/// Build an incomplete response when limits are exceeded
fn build_incomplete_response(
mut response: Value,
state: ToolLoopState,
reason: &str,
active_mcp: &Arc<crate::mcp::McpClientManager>,
original_body: &ResponsesRequest,
) -> Result<Value, String> {
let obj = response
.as_object_mut()
.ok_or_else(|| "response not an object".to_string())?;
// Set status to completed (not failed - partial success)
obj.insert("status".to_string(), Value::String("completed".to_string()));
// Set incomplete_details
obj.insert(
"incomplete_details".to_string(),
json!({ "reason": reason }),
);
// Convert any function_call in output to mcp_call format
if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
let server_label = original_body
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
.unwrap_or("mcp");
// Find any function_call items and convert them to mcp_call (incomplete)
let mut mcp_call_items = Vec::new();
for item in output_array.iter() {
if item.get("type").and_then(|t| t.as_str()) == Some("function_tool_call") {
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
let args = item
.get("arguments")
.and_then(|v| v.as_str())
.unwrap_or("{}");
// Mark as incomplete - not executed
let mcp_call_item = Self::build_mcp_call_item(
tool_name,
args,
"", // No output - wasn't executed
server_label,
false, // Not successful
Some("Not executed - response stopped due to limit"),
);
mcp_call_items.push(mcp_call_item);
}
}
// Add mcp_list_tools and executed mcp_call items at the beginning
if state.total_calls > 0 || !mcp_call_items.is_empty() {
let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label);
output_array.insert(0, list_tools_item);
// Add mcp_call items for executed calls using helper
let executed_items =
Self::build_executed_mcp_call_items(&state.conversation_history, server_label);
let mut insert_pos = 1;
for item in executed_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
// Add incomplete mcp_call items
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
}
}
// Add warning to metadata
if let Some(metadata_val) = obj.get_mut("metadata") {
if let Some(metadata_obj) = metadata_val.as_object_mut() {
if let Some(mcp_val) = metadata_obj.get_mut("mcp") {
if let Some(mcp_obj) = mcp_val.as_object_mut() {
mcp_obj.insert(
"truncation_warning".to_string(),
Value::String(format!(
"Loop terminated at {} iterations, {} total calls (reason: {})",
state.iteration, state.total_calls, reason
)),
);
}
}
}
}
Ok(response)
}
/// Execute the tool calling loop
async fn execute_tool_loop(
&self,
url: &str,
headers: Option<&HeaderMap>,
initial_payload: Value,
original_body: &ResponsesRequest,
active_mcp: &Arc<crate::mcp::McpClientManager>,
config: &McpLoopConfig,
) -> Result<Value, String> {
let mut state = ToolLoopState::new(original_body.input.clone());
// Get max_tool_calls from request (None means no user-specified limit)
let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize);
// Keep initial_payload as base template (already has fields cleaned)
let base_payload = initial_payload.clone();
let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([]));
let mut current_payload = initial_payload;
info!(
"Starting tool loop: max_tool_calls={:?}, max_iterations={}",
max_tool_calls, config.max_iterations
);
loop {
// Make request to upstream
let request_builder = self.client.post(url).json(&current_payload);
let request_builder = if let Some(headers) = headers {
apply_request_headers(headers, request_builder, true)
} else {
request_builder
};
let response = request_builder
.send()
.await
.map_err(|e| format!("upstream request failed: {}", e))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(format!("upstream error {}: {}", status, body));
}
let mut response_json = response
.json::<Value>()
.await
.map_err(|e| format!("parse response: {}", e))?;
// Check for function call
if let Some((call_id, tool_name, args_json_str)) =
Self::extract_function_call(&response_json)
{
state.iteration += 1;
state.total_calls += 1;
info!(
"Tool loop iteration {}: calling {} (call_id: {})",
state.iteration, tool_name, call_id
);
// Check combined limit: use minimum of user's max_tool_calls (if set) and safety max_iterations
let effective_limit = match max_tool_calls {
Some(user_max) => user_max.min(config.max_iterations),
None => config.max_iterations,
};
if state.total_calls > effective_limit {
if let Some(user_max) = max_tool_calls {
if state.total_calls > user_max {
warn!("Reached user-specified max_tool_calls limit: {}", user_max);
} else {
warn!(
"Reached safety max_iterations limit: {}",
config.max_iterations
);
}
} else {
warn!(
"Reached safety max_iterations limit: {}",
config.max_iterations
);
}
return Self::build_incomplete_response(
response_json,
state,
"max_tool_calls",
active_mcp,
original_body,
);
}
// Execute tool
let call_result =
Self::execute_mcp_call(active_mcp, &tool_name, &args_json_str).await;
let output_str = match call_result {
Ok((_, output)) => output,
Err(err) => {
warn!("Tool execution failed: {}", err);
// Return error as output, let model decide how to proceed
json!({ "error": err }).to_string()
}
};
// Record the call
state.record_call(call_id, tool_name, args_json_str, output_str);
// Build resume payload
current_payload = Self::build_resume_payload(
&base_payload,
&state.conversation_history,
&state.original_input,
&tools_json,
)?;
} else {
// No more tool calls, we're done
info!(
"Tool loop completed: {} iterations, {} total calls",
state.iteration, state.total_calls
);
// Inject MCP output items if we executed any tools
if state.total_calls > 0 {
let server_label = original_body
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
.unwrap_or("mcp");
// Build mcp_list_tools item
let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label);
// Insert at beginning of output array
if let Some(output_array) = response_json
.get_mut("output")
.and_then(|v| v.as_array_mut())
{
output_array.insert(0, list_tools_item);
// Build mcp_call items using helper function
let mcp_call_items = Self::build_executed_mcp_call_items(
&state.conversation_history,
server_label,
);
// Insert mcp_call items after mcp_list_tools using mutable position
let mut insert_pos = 1;
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
}
}
return Ok(response_json);
}
}
}
/// Generate a unique ID for MCP output items (similar to OpenAI format)
fn generate_mcp_id(prefix: &str) -> String {
use rand::RngCore;
@@ -1213,113 +1534,6 @@ impl OpenAIRouter {
"server_label": server_label
})
}
/// Inject mcp_list_tools and mcp_call items into the response output array
fn inject_mcp_output_items(
response_json: &mut Value,
mcp: &Arc<crate::mcp::McpClientManager>,
args: McpOutputItemsArgs,
) -> Result<(), String> {
let output_array = response_json
.get_mut("output")
.and_then(|v| v.as_array_mut())
.ok_or("missing output array")?;
// Build MCP output items
let list_tools_item = Self::build_mcp_list_tools_item(mcp, args.server_label);
let call_item = Self::build_mcp_call_item(
args.tool_name,
args.args_json,
args.output,
args.server_label,
args.success,
args.error,
);
// Find the index of the last message item to insert mcp_call before it
let call_insertion_index = output_array
.iter()
.rposition(|item| item.get("type").and_then(|v| v.as_str()) == Some("message"))
.unwrap_or(output_array.len());
// Insert items in-place for efficiency
output_array.insert(call_insertion_index, call_item);
output_array.insert(0, list_tools_item);
Ok(())
}
async fn resume_with_tool_result(&self, args: ResumeWithToolArgs<'_>) -> Result<Value, String> {
let mut payload2 = args.original_payload.clone();
let obj = payload2
.as_object_mut()
.ok_or_else(|| "payload not an object".to_string())?;
// Build function_call and tool result items per OpenAI Responses spec
let user_item = serde_json::json!({
"type": "message",
"role": "user",
"content": args.original_body.input.clone()
});
// temp system message since currently only support 1 turn of mcp function call
let system_item = serde_json::json!({
"type": "message",
"role": "system",
"content": "please resume with the following tool result, and answer user's question directly, don't trigger any more tool calls"
});
let func_item = serde_json::json!({
"type": "function_call",
"call_id": args.call_id,
"name": args.tool_name,
"arguments": args.args_json_str
});
// Build tool result item as function_call_output per OpenAI Responses spec
let tool_item = serde_json::json!({
"type": "function_call_output",
"call_id": args.call_id,
"output": args.output_str
});
obj.insert(
"input".to_string(),
Value::Array(vec![user_item, system_item, func_item, tool_item]),
);
// Ensure non-streaming and no store to upstream
obj.insert("stream".to_string(), Value::Bool(false));
obj.insert("store".to_string(), Value::Bool(false));
let mut req = self.client.post(args.url).json(&payload2);
if let Some(headers) = args.headers {
req = apply_request_headers(headers, req, true);
}
let resp = req
.send()
.await
.map_err(|e| format!("resume request failed: {}", e))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(format!("resume upstream error {}: {}", status, body));
}
let mut v = resp
.json::<Value>()
.await
.map_err(|e| format!("parse resume response: {}", e))?;
if let Some(instr) = &args.original_body.instructions {
if let Some(obj) = v.as_object_mut() {
obj.entry("instructions")
.or_insert(Value::String(instr.clone()));
}
}
// After resume, mask tools as MCP if request used MCP
Self::mask_tools_as_mcp(&mut v, args.original_body);
if let Some(obj) = v.as_object_mut() {
obj.insert("store".to_string(), Value::Bool(args.original_body.store));
}
Ok(v)
}
}
#[async_trait]