[router] Add multi-turn tool calling loop support for MCP integration (#11143)
This commit is contained in:
@@ -723,7 +723,10 @@ pub enum ResponseToolType {
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponseReasoningParam {
|
||||
#[serde(default = "default_reasoning_effort")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub effort: Option<ReasoningEffort>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub summary: Option<ReasoningSummary>,
|
||||
}
|
||||
|
||||
fn default_reasoning_effort() -> Option<ReasoningEffort> {
|
||||
@@ -738,6 +741,14 @@ pub enum ReasoningEffort {
|
||||
High,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ReasoningSummary {
|
||||
Auto,
|
||||
Concise,
|
||||
Detailed,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
|
||||
@@ -26,7 +26,6 @@ use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
time::SystemTime,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
@@ -57,6 +56,69 @@ impl std::fmt::Debug for OpenAIRouter {
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for MCP tool calling loops
|
||||
#[derive(Debug, Clone)]
|
||||
struct McpLoopConfig {
|
||||
/// Maximum iterations as safety limit (internal only, default: 10)
|
||||
/// Prevents infinite loops when max_tool_calls is not set
|
||||
max_iterations: usize,
|
||||
}
|
||||
|
||||
impl Default for McpLoopConfig {
|
||||
fn default() -> Self {
|
||||
Self { max_iterations: 10 }
|
||||
}
|
||||
}
|
||||
|
||||
/// State for tracking multi-turn tool calling loop
|
||||
struct ToolLoopState {
|
||||
/// Current iteration number (starts at 0, increments with each tool call)
|
||||
iteration: usize,
|
||||
/// Total number of tool calls executed
|
||||
total_calls: usize,
|
||||
/// Conversation history (function_call and function_call_output items)
|
||||
conversation_history: Vec<Value>,
|
||||
/// Original user input (preserved for building resume payloads)
|
||||
original_input: ResponseInput,
|
||||
}
|
||||
|
||||
impl ToolLoopState {
|
||||
fn new(original_input: ResponseInput) -> Self {
|
||||
Self {
|
||||
iteration: 0,
|
||||
total_calls: 0,
|
||||
conversation_history: Vec::new(),
|
||||
original_input,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a tool call in the loop state
|
||||
fn record_call(
|
||||
&mut self,
|
||||
call_id: String,
|
||||
tool_name: String,
|
||||
args_json_str: String,
|
||||
output_str: String,
|
||||
) {
|
||||
// Add function_call item to history
|
||||
let func_item = json!({
|
||||
"type": "function_call",
|
||||
"call_id": call_id,
|
||||
"name": tool_name,
|
||||
"arguments": args_json_str
|
||||
});
|
||||
self.conversation_history.push(func_item);
|
||||
|
||||
// Add function_call_output item to history
|
||||
let output_item = json!({
|
||||
"type": "function_call_output",
|
||||
"call_id": call_id,
|
||||
"output": output_str
|
||||
});
|
||||
self.conversation_history.push(output_item);
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper that parses SSE frames from the OpenAI responses stream and
|
||||
/// accumulates enough information to persist the final response locally.
|
||||
struct StreamingResponseAccumulator {
|
||||
@@ -388,126 +450,32 @@ impl OpenAIRouter {
|
||||
obj.insert("store".to_string(), Value::Bool(original_body.store));
|
||||
}
|
||||
|
||||
let mut final_response_json = openai_response_json;
|
||||
|
||||
if let Some(mcp) = active_mcp {
|
||||
if let Some((call_id, tool_name, args_json_str)) =
|
||||
Self::extract_function_call(&final_response_json)
|
||||
{
|
||||
info!(
|
||||
"Detected function call: name={}, call_id={}, args={}",
|
||||
tool_name, call_id, args_json_str
|
||||
);
|
||||
|
||||
let call_started = SystemTime::now();
|
||||
let call_result =
|
||||
Self::execute_mcp_call(mcp, &tool_name, &args_json_str).await;
|
||||
let call_duration_ms =
|
||||
call_started.elapsed().unwrap_or_default().as_millis();
|
||||
|
||||
let (output_payload, call_ok, call_error) = match call_result {
|
||||
Ok((server, out)) => {
|
||||
info!(
|
||||
call_id = %call_id,
|
||||
tool_name = %tool_name,
|
||||
server = %server,
|
||||
duration_ms = call_duration_ms,
|
||||
"MCP tool call succeeded"
|
||||
);
|
||||
(out, true, None)
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
call_id = %call_id,
|
||||
tool_name = %tool_name,
|
||||
duration_ms = call_duration_ms,
|
||||
error = %err,
|
||||
"MCP tool call failed"
|
||||
);
|
||||
(
|
||||
serde_json::json!({
|
||||
"error": err
|
||||
})
|
||||
.to_string(),
|
||||
false,
|
||||
Some(err),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// If MCP is active and we detect a function call, enter the tool loop
|
||||
let mut final_response_json = if let Some(mcp) = active_mcp {
|
||||
if Self::extract_function_call(&openai_response_json).is_some() {
|
||||
// Use the loop to handle potentially multiple tool calls
|
||||
let loop_config = McpLoopConfig::default();
|
||||
match self
|
||||
.resume_with_tool_result(ResumeWithToolArgs {
|
||||
url: &url,
|
||||
.execute_tool_loop(
|
||||
&url,
|
||||
headers,
|
||||
original_payload: &payload,
|
||||
call_id: &call_id,
|
||||
tool_name: &tool_name,
|
||||
args_json_str: &args_json_str,
|
||||
output_str: &output_payload,
|
||||
payload.clone(),
|
||||
original_body,
|
||||
})
|
||||
mcp,
|
||||
&loop_config,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(mut resumed_json) => {
|
||||
// Inject MCP output items (mcp_list_tools and mcp_call)
|
||||
let server_label = original_body
|
||||
.tools
|
||||
.iter()
|
||||
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
|
||||
.and_then(|t| t.server_label.as_deref())
|
||||
.unwrap_or("mcp");
|
||||
|
||||
if let Err(inject_err) = Self::inject_mcp_output_items(
|
||||
&mut resumed_json,
|
||||
mcp,
|
||||
McpOutputItemsArgs {
|
||||
tool_name: &tool_name,
|
||||
args_json: &args_json_str,
|
||||
output: &output_payload,
|
||||
server_label,
|
||||
success: call_ok,
|
||||
error: call_error.as_deref(),
|
||||
},
|
||||
) {
|
||||
warn!(
|
||||
"Failed to inject MCP output items: {}",
|
||||
inject_err
|
||||
);
|
||||
}
|
||||
|
||||
if !call_ok {
|
||||
if let Some(obj) = resumed_json.as_object_mut() {
|
||||
let metadata_value =
|
||||
obj.entry("metadata").or_insert_with(|| {
|
||||
Value::Object(serde_json::Map::new())
|
||||
});
|
||||
if let Some(metadata) =
|
||||
metadata_value.as_object_mut()
|
||||
{
|
||||
if let Some(err_msg) = call_error.as_ref() {
|
||||
metadata.insert(
|
||||
"mcp_error".to_string(),
|
||||
Value::String(err_msg.clone()),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
final_response_json = resumed_json;
|
||||
}
|
||||
Ok(loop_result) => loop_result,
|
||||
Err(err) => {
|
||||
warn!("Failed to resume with tool result: {}", err);
|
||||
warn!("Tool loop failed: {}", err);
|
||||
let error_body = json!({
|
||||
"error": {
|
||||
"message": format!(
|
||||
"Failed to resume with tool result: {}",
|
||||
err
|
||||
),
|
||||
"message": format!("Tool loop failed: {}", err),
|
||||
"type": "internal_error",
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
[("content-type", "application/json")],
|
||||
@@ -517,10 +485,14 @@ impl OpenAIRouter {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!("No function call found in upstream response; skipping MCP");
|
||||
// No function call detected, use response as-is
|
||||
openai_response_json
|
||||
}
|
||||
}
|
||||
} else {
|
||||
openai_response_json
|
||||
};
|
||||
|
||||
// Mask tools back to MCP format for client
|
||||
Self::mask_tools_as_mcp(&mut final_response_json, original_body);
|
||||
if original_body.store {
|
||||
if let Err(e) = self
|
||||
@@ -1040,26 +1012,6 @@ impl OpenAIRouter {
|
||||
}
|
||||
}
|
||||
|
||||
struct ResumeWithToolArgs<'a> {
|
||||
url: &'a str,
|
||||
headers: Option<&'a HeaderMap>,
|
||||
original_payload: &'a Value,
|
||||
call_id: &'a str,
|
||||
tool_name: &'a str,
|
||||
args_json_str: &'a str,
|
||||
output_str: &'a str,
|
||||
original_body: &'a ResponsesRequest,
|
||||
}
|
||||
|
||||
struct McpOutputItemsArgs<'a> {
|
||||
tool_name: &'a str,
|
||||
args_json: &'a str,
|
||||
output: &'a str,
|
||||
server_label: &'a str,
|
||||
success: bool,
|
||||
error: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl OpenAIRouter {
|
||||
fn extract_function_call(resp: &Value) -> Option<(String, String, String)> {
|
||||
let output = resp.get("output")?.as_array()?;
|
||||
@@ -1150,6 +1102,375 @@ impl OpenAIRouter {
|
||||
Ok((server_name, output_str))
|
||||
}
|
||||
|
||||
/// Build a resume payload with conversation history
|
||||
fn build_resume_payload(
|
||||
base_payload: &Value,
|
||||
conversation_history: &[Value],
|
||||
original_input: &ResponseInput,
|
||||
tools_json: &Value,
|
||||
) -> Result<Value, String> {
|
||||
// Clone the base payload which already has cleaned fields
|
||||
let mut payload = base_payload.clone();
|
||||
|
||||
let obj = payload
|
||||
.as_object_mut()
|
||||
.ok_or_else(|| "payload not an object".to_string())?;
|
||||
|
||||
// Build input array: start with original user input
|
||||
let mut input_array = Vec::new();
|
||||
|
||||
// Add original user message
|
||||
// For structured input, serialize the original input items
|
||||
match original_input {
|
||||
ResponseInput::Text(text) => {
|
||||
let user_item = json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{ "type": "input_text", "text": text }]
|
||||
});
|
||||
input_array.push(user_item);
|
||||
}
|
||||
ResponseInput::Items(items) => {
|
||||
// Items are already structured ResponseInputOutputItem, convert to JSON
|
||||
if let Ok(items_value) = serde_json::to_value(items) {
|
||||
if let Some(items_arr) = items_value.as_array() {
|
||||
input_array.extend_from_slice(items_arr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add all conversation history (function calls and outputs)
|
||||
input_array.extend_from_slice(conversation_history);
|
||||
|
||||
obj.insert("input".to_string(), Value::Array(input_array));
|
||||
|
||||
// Use the transformed tools (function tools, not MCP tools)
|
||||
if let Some(tools_arr) = tools_json.as_array() {
|
||||
if !tools_arr.is_empty() {
|
||||
obj.insert("tools".to_string(), tools_json.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure non-streaming and no store to upstream
|
||||
obj.insert("stream".to_string(), Value::Bool(false));
|
||||
obj.insert("store".to_string(), Value::Bool(false));
|
||||
|
||||
// Note: SGLang-specific fields were already removed from base_payload
|
||||
// before it was passed to execute_tool_loop (see route_responses lines 1935-1946)
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
|
||||
/// Helper function to build mcp_call items from executed tool calls in conversation history
|
||||
fn build_executed_mcp_call_items(
|
||||
conversation_history: &[Value],
|
||||
server_label: &str,
|
||||
) -> Vec<Value> {
|
||||
let mut mcp_call_items = Vec::new();
|
||||
|
||||
for item in conversation_history {
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("function_call") {
|
||||
let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let args = item
|
||||
.get("arguments")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("{}");
|
||||
|
||||
// Find corresponding output
|
||||
let output_item = conversation_history.iter().find(|o| {
|
||||
o.get("type").and_then(|t| t.as_str()) == Some("function_call_output")
|
||||
&& o.get("call_id").and_then(|c| c.as_str()) == Some(call_id)
|
||||
});
|
||||
|
||||
let output_str = output_item
|
||||
.and_then(|o| o.get("output").and_then(|v| v.as_str()))
|
||||
.unwrap_or("{}");
|
||||
|
||||
// Check if output contains error by parsing JSON
|
||||
let is_error = serde_json::from_str::<serde_json::Value>(output_str)
|
||||
.map(|v| v.get("error").is_some())
|
||||
.unwrap_or(false);
|
||||
|
||||
let mcp_call_item = Self::build_mcp_call_item(
|
||||
tool_name,
|
||||
args,
|
||||
output_str,
|
||||
server_label,
|
||||
!is_error,
|
||||
if is_error {
|
||||
Some("Tool execution failed")
|
||||
} else {
|
||||
None
|
||||
},
|
||||
);
|
||||
mcp_call_items.push(mcp_call_item);
|
||||
}
|
||||
}
|
||||
|
||||
mcp_call_items
|
||||
}
|
||||
|
||||
/// Build an incomplete response when limits are exceeded
|
||||
fn build_incomplete_response(
|
||||
mut response: Value,
|
||||
state: ToolLoopState,
|
||||
reason: &str,
|
||||
active_mcp: &Arc<crate::mcp::McpClientManager>,
|
||||
original_body: &ResponsesRequest,
|
||||
) -> Result<Value, String> {
|
||||
let obj = response
|
||||
.as_object_mut()
|
||||
.ok_or_else(|| "response not an object".to_string())?;
|
||||
|
||||
// Set status to completed (not failed - partial success)
|
||||
obj.insert("status".to_string(), Value::String("completed".to_string()));
|
||||
|
||||
// Set incomplete_details
|
||||
obj.insert(
|
||||
"incomplete_details".to_string(),
|
||||
json!({ "reason": reason }),
|
||||
);
|
||||
|
||||
// Convert any function_call in output to mcp_call format
|
||||
if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
|
||||
let server_label = original_body
|
||||
.tools
|
||||
.iter()
|
||||
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
|
||||
.and_then(|t| t.server_label.as_deref())
|
||||
.unwrap_or("mcp");
|
||||
|
||||
// Find any function_call items and convert them to mcp_call (incomplete)
|
||||
let mut mcp_call_items = Vec::new();
|
||||
for item in output_array.iter() {
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("function_tool_call") {
|
||||
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let args = item
|
||||
.get("arguments")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("{}");
|
||||
|
||||
// Mark as incomplete - not executed
|
||||
let mcp_call_item = Self::build_mcp_call_item(
|
||||
tool_name,
|
||||
args,
|
||||
"", // No output - wasn't executed
|
||||
server_label,
|
||||
false, // Not successful
|
||||
Some("Not executed - response stopped due to limit"),
|
||||
);
|
||||
mcp_call_items.push(mcp_call_item);
|
||||
}
|
||||
}
|
||||
|
||||
// Add mcp_list_tools and executed mcp_call items at the beginning
|
||||
if state.total_calls > 0 || !mcp_call_items.is_empty() {
|
||||
let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label);
|
||||
output_array.insert(0, list_tools_item);
|
||||
|
||||
// Add mcp_call items for executed calls using helper
|
||||
let executed_items =
|
||||
Self::build_executed_mcp_call_items(&state.conversation_history, server_label);
|
||||
|
||||
let mut insert_pos = 1;
|
||||
for item in executed_items {
|
||||
output_array.insert(insert_pos, item);
|
||||
insert_pos += 1;
|
||||
}
|
||||
|
||||
// Add incomplete mcp_call items
|
||||
for item in mcp_call_items {
|
||||
output_array.insert(insert_pos, item);
|
||||
insert_pos += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add warning to metadata
|
||||
if let Some(metadata_val) = obj.get_mut("metadata") {
|
||||
if let Some(metadata_obj) = metadata_val.as_object_mut() {
|
||||
if let Some(mcp_val) = metadata_obj.get_mut("mcp") {
|
||||
if let Some(mcp_obj) = mcp_val.as_object_mut() {
|
||||
mcp_obj.insert(
|
||||
"truncation_warning".to_string(),
|
||||
Value::String(format!(
|
||||
"Loop terminated at {} iterations, {} total calls (reason: {})",
|
||||
state.iteration, state.total_calls, reason
|
||||
)),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Execute the tool calling loop
|
||||
async fn execute_tool_loop(
|
||||
&self,
|
||||
url: &str,
|
||||
headers: Option<&HeaderMap>,
|
||||
initial_payload: Value,
|
||||
original_body: &ResponsesRequest,
|
||||
active_mcp: &Arc<crate::mcp::McpClientManager>,
|
||||
config: &McpLoopConfig,
|
||||
) -> Result<Value, String> {
|
||||
let mut state = ToolLoopState::new(original_body.input.clone());
|
||||
|
||||
// Get max_tool_calls from request (None means no user-specified limit)
|
||||
let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize);
|
||||
|
||||
// Keep initial_payload as base template (already has fields cleaned)
|
||||
let base_payload = initial_payload.clone();
|
||||
let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([]));
|
||||
let mut current_payload = initial_payload;
|
||||
|
||||
info!(
|
||||
"Starting tool loop: max_tool_calls={:?}, max_iterations={}",
|
||||
max_tool_calls, config.max_iterations
|
||||
);
|
||||
|
||||
loop {
|
||||
// Make request to upstream
|
||||
let request_builder = self.client.post(url).json(¤t_payload);
|
||||
let request_builder = if let Some(headers) = headers {
|
||||
apply_request_headers(headers, request_builder, true)
|
||||
} else {
|
||||
request_builder
|
||||
};
|
||||
|
||||
let response = request_builder
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("upstream request failed: {}", e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
return Err(format!("upstream error {}: {}", status, body));
|
||||
}
|
||||
|
||||
let mut response_json = response
|
||||
.json::<Value>()
|
||||
.await
|
||||
.map_err(|e| format!("parse response: {}", e))?;
|
||||
|
||||
// Check for function call
|
||||
if let Some((call_id, tool_name, args_json_str)) =
|
||||
Self::extract_function_call(&response_json)
|
||||
{
|
||||
state.iteration += 1;
|
||||
state.total_calls += 1;
|
||||
|
||||
info!(
|
||||
"Tool loop iteration {}: calling {} (call_id: {})",
|
||||
state.iteration, tool_name, call_id
|
||||
);
|
||||
|
||||
// Check combined limit: use minimum of user's max_tool_calls (if set) and safety max_iterations
|
||||
let effective_limit = match max_tool_calls {
|
||||
Some(user_max) => user_max.min(config.max_iterations),
|
||||
None => config.max_iterations,
|
||||
};
|
||||
|
||||
if state.total_calls > effective_limit {
|
||||
if let Some(user_max) = max_tool_calls {
|
||||
if state.total_calls > user_max {
|
||||
warn!("Reached user-specified max_tool_calls limit: {}", user_max);
|
||||
} else {
|
||||
warn!(
|
||||
"Reached safety max_iterations limit: {}",
|
||||
config.max_iterations
|
||||
);
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
"Reached safety max_iterations limit: {}",
|
||||
config.max_iterations
|
||||
);
|
||||
}
|
||||
|
||||
return Self::build_incomplete_response(
|
||||
response_json,
|
||||
state,
|
||||
"max_tool_calls",
|
||||
active_mcp,
|
||||
original_body,
|
||||
);
|
||||
}
|
||||
|
||||
// Execute tool
|
||||
let call_result =
|
||||
Self::execute_mcp_call(active_mcp, &tool_name, &args_json_str).await;
|
||||
|
||||
let output_str = match call_result {
|
||||
Ok((_, output)) => output,
|
||||
Err(err) => {
|
||||
warn!("Tool execution failed: {}", err);
|
||||
// Return error as output, let model decide how to proceed
|
||||
json!({ "error": err }).to_string()
|
||||
}
|
||||
};
|
||||
|
||||
// Record the call
|
||||
state.record_call(call_id, tool_name, args_json_str, output_str);
|
||||
|
||||
// Build resume payload
|
||||
current_payload = Self::build_resume_payload(
|
||||
&base_payload,
|
||||
&state.conversation_history,
|
||||
&state.original_input,
|
||||
&tools_json,
|
||||
)?;
|
||||
} else {
|
||||
// No more tool calls, we're done
|
||||
info!(
|
||||
"Tool loop completed: {} iterations, {} total calls",
|
||||
state.iteration, state.total_calls
|
||||
);
|
||||
|
||||
// Inject MCP output items if we executed any tools
|
||||
if state.total_calls > 0 {
|
||||
let server_label = original_body
|
||||
.tools
|
||||
.iter()
|
||||
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
|
||||
.and_then(|t| t.server_label.as_deref())
|
||||
.unwrap_or("mcp");
|
||||
|
||||
// Build mcp_list_tools item
|
||||
let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label);
|
||||
|
||||
// Insert at beginning of output array
|
||||
if let Some(output_array) = response_json
|
||||
.get_mut("output")
|
||||
.and_then(|v| v.as_array_mut())
|
||||
{
|
||||
output_array.insert(0, list_tools_item);
|
||||
|
||||
// Build mcp_call items using helper function
|
||||
let mcp_call_items = Self::build_executed_mcp_call_items(
|
||||
&state.conversation_history,
|
||||
server_label,
|
||||
);
|
||||
|
||||
// Insert mcp_call items after mcp_list_tools using mutable position
|
||||
let mut insert_pos = 1;
|
||||
for item in mcp_call_items {
|
||||
output_array.insert(insert_pos, item);
|
||||
insert_pos += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(response_json);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a unique ID for MCP output items (similar to OpenAI format)
|
||||
fn generate_mcp_id(prefix: &str) -> String {
|
||||
use rand::RngCore;
|
||||
@@ -1213,113 +1534,6 @@ impl OpenAIRouter {
|
||||
"server_label": server_label
|
||||
})
|
||||
}
|
||||
|
||||
/// Inject mcp_list_tools and mcp_call items into the response output array
|
||||
fn inject_mcp_output_items(
|
||||
response_json: &mut Value,
|
||||
mcp: &Arc<crate::mcp::McpClientManager>,
|
||||
args: McpOutputItemsArgs,
|
||||
) -> Result<(), String> {
|
||||
let output_array = response_json
|
||||
.get_mut("output")
|
||||
.and_then(|v| v.as_array_mut())
|
||||
.ok_or("missing output array")?;
|
||||
|
||||
// Build MCP output items
|
||||
let list_tools_item = Self::build_mcp_list_tools_item(mcp, args.server_label);
|
||||
let call_item = Self::build_mcp_call_item(
|
||||
args.tool_name,
|
||||
args.args_json,
|
||||
args.output,
|
||||
args.server_label,
|
||||
args.success,
|
||||
args.error,
|
||||
);
|
||||
|
||||
// Find the index of the last message item to insert mcp_call before it
|
||||
let call_insertion_index = output_array
|
||||
.iter()
|
||||
.rposition(|item| item.get("type").and_then(|v| v.as_str()) == Some("message"))
|
||||
.unwrap_or(output_array.len());
|
||||
|
||||
// Insert items in-place for efficiency
|
||||
output_array.insert(call_insertion_index, call_item);
|
||||
output_array.insert(0, list_tools_item);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn resume_with_tool_result(&self, args: ResumeWithToolArgs<'_>) -> Result<Value, String> {
|
||||
let mut payload2 = args.original_payload.clone();
|
||||
let obj = payload2
|
||||
.as_object_mut()
|
||||
.ok_or_else(|| "payload not an object".to_string())?;
|
||||
|
||||
// Build function_call and tool result items per OpenAI Responses spec
|
||||
let user_item = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": args.original_body.input.clone()
|
||||
});
|
||||
// temp system message since currently only support 1 turn of mcp function call
|
||||
let system_item = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "system",
|
||||
"content": "please resume with the following tool result, and answer user's question directly, don't trigger any more tool calls"
|
||||
});
|
||||
|
||||
let func_item = serde_json::json!({
|
||||
"type": "function_call",
|
||||
"call_id": args.call_id,
|
||||
"name": args.tool_name,
|
||||
"arguments": args.args_json_str
|
||||
});
|
||||
// Build tool result item as function_call_output per OpenAI Responses spec
|
||||
let tool_item = serde_json::json!({
|
||||
"type": "function_call_output",
|
||||
"call_id": args.call_id,
|
||||
"output": args.output_str
|
||||
});
|
||||
|
||||
obj.insert(
|
||||
"input".to_string(),
|
||||
Value::Array(vec![user_item, system_item, func_item, tool_item]),
|
||||
);
|
||||
|
||||
// Ensure non-streaming and no store to upstream
|
||||
obj.insert("stream".to_string(), Value::Bool(false));
|
||||
obj.insert("store".to_string(), Value::Bool(false));
|
||||
let mut req = self.client.post(args.url).json(&payload2);
|
||||
if let Some(headers) = args.headers {
|
||||
req = apply_request_headers(headers, req, true);
|
||||
}
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("resume request failed: {}", e))?;
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("resume upstream error {}: {}", status, body));
|
||||
}
|
||||
let mut v = resp
|
||||
.json::<Value>()
|
||||
.await
|
||||
.map_err(|e| format!("parse resume response: {}", e))?;
|
||||
|
||||
if let Some(instr) = &args.original_body.instructions {
|
||||
if let Some(obj) = v.as_object_mut() {
|
||||
obj.entry("instructions")
|
||||
.or_insert(Value::String(instr.clone()));
|
||||
}
|
||||
}
|
||||
// After resume, mask tools as MCP if request used MCP
|
||||
Self::mask_tools_as_mcp(&mut v, args.original_body);
|
||||
if let Some(obj) = v.as_object_mut() {
|
||||
obj.insert("store".to_string(), Value::Bool(args.original_body.store));
|
||||
}
|
||||
Ok(v)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
||||
@@ -252,6 +252,7 @@ fn test_responses_request_creation() {
|
||||
previous_response_id: None,
|
||||
reasoning: Some(ResponseReasoningParam {
|
||||
effort: Some(ReasoningEffort::Medium),
|
||||
summary: None,
|
||||
}),
|
||||
service_tier: ServiceTier::Auto,
|
||||
store: true,
|
||||
@@ -380,6 +381,7 @@ fn test_usage_conversion() {
|
||||
fn test_reasoning_param_default() {
|
||||
let param = ResponseReasoningParam {
|
||||
effort: Some(ReasoningEffort::Medium),
|
||||
summary: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(¶m).unwrap();
|
||||
@@ -403,6 +405,7 @@ fn test_json_serialization() {
|
||||
previous_response_id: None,
|
||||
reasoning: Some(ResponseReasoningParam {
|
||||
effort: Some(ReasoningEffort::High),
|
||||
summary: None,
|
||||
}),
|
||||
service_tier: ServiceTier::Priority,
|
||||
store: false,
|
||||
@@ -437,3 +440,328 @@ fn test_json_serialization() {
|
||||
assert!(parsed.stream);
|
||||
assert_eq!(parsed.tools.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_turn_loop_with_mcp() {
|
||||
// This test verifies the multi-turn loop functionality:
|
||||
// 1. Initial request with MCP tools
|
||||
// 2. Mock worker returns function_call
|
||||
// 3. Router executes MCP tool and resumes
|
||||
// 4. Mock worker returns final answer
|
||||
// 5. Verify the complete flow worked
|
||||
|
||||
// Start mock MCP server
|
||||
let mut mcp = MockMCPServer::start().await.expect("start mcp");
|
||||
|
||||
// Write a temp MCP config file
|
||||
let mcp_yaml = format!(
|
||||
"servers:\n - name: mock\n protocol: streamable\n url: {}\n",
|
||||
mcp.url()
|
||||
);
|
||||
let dir = tempfile::tempdir().expect("tmpdir");
|
||||
let cfg_path = dir.path().join("mcp.yaml");
|
||||
std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
|
||||
std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap());
|
||||
|
||||
// Start mock OpenAI worker
|
||||
let mut worker = MockWorker::new(MockWorkerConfig {
|
||||
port: 0,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
});
|
||||
let worker_url = worker.start().await.expect("start worker");
|
||||
|
||||
// Build router config
|
||||
let router_cfg = RouterConfig {
|
||||
mode: RoutingMode::OpenAI {
|
||||
worker_urls: vec![worker_url],
|
||||
},
|
||||
connection_mode: ConnectionMode::Http,
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 0,
|
||||
max_payload_size: 8 * 1024 * 1024,
|
||||
request_timeout_secs: 60,
|
||||
worker_startup_timeout_secs: 5,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: Some("info".to_string()),
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 32,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 5,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
|
||||
// Build request with MCP tools
|
||||
let req = ResponsesRequest {
|
||||
background: false,
|
||||
include: None,
|
||||
input: ResponseInput::Text("search for SGLang".to_string()),
|
||||
instructions: Some("Be helpful".to_string()),
|
||||
max_output_tokens: Some(128),
|
||||
max_tool_calls: None, // No limit - test unlimited
|
||||
metadata: None,
|
||||
model: Some("mock-model".to_string()),
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
service_tier: ServiceTier::Auto,
|
||||
store: true,
|
||||
stream: false,
|
||||
temperature: Some(0.7),
|
||||
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::Mcp,
|
||||
server_url: Some(mcp.url()),
|
||||
server_label: Some("mock".to_string()),
|
||||
server_description: Some("Mock MCP server for testing".to_string()),
|
||||
require_approval: Some("never".to_string()),
|
||||
..Default::default()
|
||||
}],
|
||||
top_logprobs: 0,
|
||||
top_p: Some(1.0),
|
||||
truncation: Truncation::Disabled,
|
||||
user: None,
|
||||
request_id: "resp_multi_turn_test".to_string(),
|
||||
priority: 0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop: None,
|
||||
top_k: 50,
|
||||
min_p: 0.0,
|
||||
repetition_penalty: 1.0,
|
||||
};
|
||||
|
||||
// Execute the request (this should trigger the multi-turn loop)
|
||||
let response = router.route_responses(None, &req, None).await;
|
||||
|
||||
// Check status
|
||||
assert_eq!(
|
||||
response.status(),
|
||||
axum::http::StatusCode::OK,
|
||||
"Request should succeed"
|
||||
);
|
||||
|
||||
// Read the response body
|
||||
use axum::body::to_bytes;
|
||||
let response_body = response.into_body();
|
||||
let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
|
||||
let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
|
||||
println!(
|
||||
"Multi-turn response: {}",
|
||||
serde_json::to_string_pretty(&response_json).unwrap()
|
||||
);
|
||||
|
||||
// Verify the response structure
|
||||
assert_eq!(response_json["object"], "response");
|
||||
assert_eq!(response_json["status"], "completed");
|
||||
// Note: mock worker generates its own ID, so we just verify it exists
|
||||
assert!(
|
||||
response_json["id"].is_string(),
|
||||
"Response should have an id"
|
||||
);
|
||||
|
||||
// Check that output contains final message
|
||||
let output = response_json["output"]
|
||||
.as_array()
|
||||
.expect("output should be array");
|
||||
assert!(!output.is_empty(), "output should not be empty");
|
||||
|
||||
// Find the final message with text
|
||||
let has_final_text = output.iter().any(|item| {
|
||||
item.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|t| t == "message")
|
||||
.unwrap_or(false)
|
||||
&& item
|
||||
.get("content")
|
||||
.and_then(|c| c.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter().any(|part| {
|
||||
part.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|t| t == "output_text")
|
||||
.unwrap_or(false)
|
||||
})
|
||||
})
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
assert!(has_final_text, "Should have final text output");
|
||||
|
||||
// Verify tools are masked back to MCP format
|
||||
let tools = response_json["tools"]
|
||||
.as_array()
|
||||
.expect("tools should be array");
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0]["type"], "mcp");
|
||||
assert_eq!(tools[0]["server_label"], "mock");
|
||||
|
||||
// Clean up
|
||||
std::env::remove_var("SGLANG_MCP_CONFIG");
|
||||
worker.stop().await;
|
||||
mcp.stop().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_max_tool_calls_limit() {
|
||||
// This test verifies that max_tool_calls is respected
|
||||
// Note: The mock worker returns a final answer after one tool call,
|
||||
// so with max_tool_calls=1, it completes normally (doesn't exceed the limit)
|
||||
|
||||
let mut mcp = MockMCPServer::start().await.expect("start mcp");
|
||||
let mcp_yaml = format!(
|
||||
"servers:\n - name: mock\n protocol: streamable\n url: {}\n",
|
||||
mcp.url()
|
||||
);
|
||||
let dir = tempfile::tempdir().expect("tmpdir");
|
||||
let cfg_path = dir.path().join("mcp.yaml");
|
||||
std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
|
||||
std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap());
|
||||
|
||||
let mut worker = MockWorker::new(MockWorkerConfig {
|
||||
port: 0,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
});
|
||||
let worker_url = worker.start().await.expect("start worker");
|
||||
|
||||
let router_cfg = RouterConfig {
|
||||
mode: RoutingMode::OpenAI {
|
||||
worker_urls: vec![worker_url],
|
||||
},
|
||||
connection_mode: ConnectionMode::Http,
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 0,
|
||||
max_payload_size: 8 * 1024 * 1024,
|
||||
request_timeout_secs: 60,
|
||||
worker_startup_timeout_secs: 5,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: Some("info".to_string()),
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 32,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 5,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
|
||||
let req = ResponsesRequest {
|
||||
background: false,
|
||||
include: None,
|
||||
input: ResponseInput::Text("test max calls".to_string()),
|
||||
instructions: None,
|
||||
max_output_tokens: Some(128),
|
||||
max_tool_calls: Some(1), // Limit to 1 call
|
||||
metadata: None,
|
||||
model: Some("mock-model".to_string()),
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
service_tier: ServiceTier::Auto,
|
||||
store: false,
|
||||
stream: false,
|
||||
temperature: Some(0.7),
|
||||
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::Mcp,
|
||||
server_url: Some(mcp.url()),
|
||||
server_label: Some("mock".to_string()),
|
||||
..Default::default()
|
||||
}],
|
||||
top_logprobs: 0,
|
||||
top_p: Some(1.0),
|
||||
truncation: Truncation::Disabled,
|
||||
user: None,
|
||||
request_id: "resp_max_calls_test".to_string(),
|
||||
priority: 0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop: None,
|
||||
top_k: 50,
|
||||
min_p: 0.0,
|
||||
repetition_penalty: 1.0,
|
||||
};
|
||||
|
||||
let response = router.route_responses(None, &req, None).await;
|
||||
assert_eq!(response.status(), axum::http::StatusCode::OK);
|
||||
|
||||
use axum::body::to_bytes;
|
||||
let response_body = response.into_body();
|
||||
let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
|
||||
let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
|
||||
println!(
|
||||
"Max calls response: {}",
|
||||
serde_json::to_string_pretty(&response_json).unwrap()
|
||||
);
|
||||
|
||||
// With max_tool_calls=1, the mock returns a final answer after 1 call
|
||||
// So it completes normally without exceeding the limit
|
||||
assert_eq!(response_json["status"], "completed");
|
||||
|
||||
// Verify the basic response structure
|
||||
assert!(response_json["id"].is_string());
|
||||
assert_eq!(response_json["object"], "response");
|
||||
|
||||
// The response should have tools masked back to MCP format
|
||||
let tools = response_json["tools"]
|
||||
.as_array()
|
||||
.expect("tools should be array");
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0]["type"], "mcp");
|
||||
|
||||
// Note: To test actual limit exceeding, we would need a mock that keeps
|
||||
// calling tools indefinitely, which would hit max_iterations (safety limit)
|
||||
|
||||
std::env::remove_var("SGLANG_MCP_CONFIG");
|
||||
worker.stop().await;
|
||||
mcp.stop().await;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user