[router] Add multi-turn tool calling loop support for MCP integration (#11143)
This commit is contained in:
@@ -723,7 +723,10 @@ pub enum ResponseToolType {
|
|||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ResponseReasoningParam {
|
pub struct ResponseReasoningParam {
|
||||||
#[serde(default = "default_reasoning_effort")]
|
#[serde(default = "default_reasoning_effort")]
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub effort: Option<ReasoningEffort>,
|
pub effort: Option<ReasoningEffort>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub summary: Option<ReasoningSummary>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_reasoning_effort() -> Option<ReasoningEffort> {
|
fn default_reasoning_effort() -> Option<ReasoningEffort> {
|
||||||
@@ -738,6 +741,14 @@ pub enum ReasoningEffort {
|
|||||||
High,
|
High,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ReasoningSummary {
|
||||||
|
Auto,
|
||||||
|
Concise,
|
||||||
|
Detailed,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ use std::{
|
|||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
io,
|
io,
|
||||||
sync::{atomic::AtomicBool, Arc},
|
sync::{atomic::AtomicBool, Arc},
|
||||||
time::SystemTime,
|
|
||||||
};
|
};
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
@@ -57,6 +56,69 @@ impl std::fmt::Debug for OpenAIRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Configuration for MCP tool calling loops
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct McpLoopConfig {
|
||||||
|
/// Maximum iterations as safety limit (internal only, default: 10)
|
||||||
|
/// Prevents infinite loops when max_tool_calls is not set
|
||||||
|
max_iterations: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for McpLoopConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self { max_iterations: 10 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// State for tracking multi-turn tool calling loop
|
||||||
|
struct ToolLoopState {
|
||||||
|
/// Current iteration number (starts at 0, increments with each tool call)
|
||||||
|
iteration: usize,
|
||||||
|
/// Total number of tool calls executed
|
||||||
|
total_calls: usize,
|
||||||
|
/// Conversation history (function_call and function_call_output items)
|
||||||
|
conversation_history: Vec<Value>,
|
||||||
|
/// Original user input (preserved for building resume payloads)
|
||||||
|
original_input: ResponseInput,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolLoopState {
|
||||||
|
fn new(original_input: ResponseInput) -> Self {
|
||||||
|
Self {
|
||||||
|
iteration: 0,
|
||||||
|
total_calls: 0,
|
||||||
|
conversation_history: Vec::new(),
|
||||||
|
original_input,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record a tool call in the loop state
|
||||||
|
fn record_call(
|
||||||
|
&mut self,
|
||||||
|
call_id: String,
|
||||||
|
tool_name: String,
|
||||||
|
args_json_str: String,
|
||||||
|
output_str: String,
|
||||||
|
) {
|
||||||
|
// Add function_call item to history
|
||||||
|
let func_item = json!({
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": call_id,
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": args_json_str
|
||||||
|
});
|
||||||
|
self.conversation_history.push(func_item);
|
||||||
|
|
||||||
|
// Add function_call_output item to history
|
||||||
|
let output_item = json!({
|
||||||
|
"type": "function_call_output",
|
||||||
|
"call_id": call_id,
|
||||||
|
"output": output_str
|
||||||
|
});
|
||||||
|
self.conversation_history.push(output_item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper that parses SSE frames from the OpenAI responses stream and
|
/// Helper that parses SSE frames from the OpenAI responses stream and
|
||||||
/// accumulates enough information to persist the final response locally.
|
/// accumulates enough information to persist the final response locally.
|
||||||
struct StreamingResponseAccumulator {
|
struct StreamingResponseAccumulator {
|
||||||
@@ -388,126 +450,32 @@ impl OpenAIRouter {
|
|||||||
obj.insert("store".to_string(), Value::Bool(original_body.store));
|
obj.insert("store".to_string(), Value::Bool(original_body.store));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut final_response_json = openai_response_json;
|
// If MCP is active and we detect a function call, enter the tool loop
|
||||||
|
let mut final_response_json = if let Some(mcp) = active_mcp {
|
||||||
if let Some(mcp) = active_mcp {
|
if Self::extract_function_call(&openai_response_json).is_some() {
|
||||||
if let Some((call_id, tool_name, args_json_str)) =
|
// Use the loop to handle potentially multiple tool calls
|
||||||
Self::extract_function_call(&final_response_json)
|
let loop_config = McpLoopConfig::default();
|
||||||
{
|
|
||||||
info!(
|
|
||||||
"Detected function call: name={}, call_id={}, args={}",
|
|
||||||
tool_name, call_id, args_json_str
|
|
||||||
);
|
|
||||||
|
|
||||||
let call_started = SystemTime::now();
|
|
||||||
let call_result =
|
|
||||||
Self::execute_mcp_call(mcp, &tool_name, &args_json_str).await;
|
|
||||||
let call_duration_ms =
|
|
||||||
call_started.elapsed().unwrap_or_default().as_millis();
|
|
||||||
|
|
||||||
let (output_payload, call_ok, call_error) = match call_result {
|
|
||||||
Ok((server, out)) => {
|
|
||||||
info!(
|
|
||||||
call_id = %call_id,
|
|
||||||
tool_name = %tool_name,
|
|
||||||
server = %server,
|
|
||||||
duration_ms = call_duration_ms,
|
|
||||||
"MCP tool call succeeded"
|
|
||||||
);
|
|
||||||
(out, true, None)
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
warn!(
|
|
||||||
call_id = %call_id,
|
|
||||||
tool_name = %tool_name,
|
|
||||||
duration_ms = call_duration_ms,
|
|
||||||
error = %err,
|
|
||||||
"MCP tool call failed"
|
|
||||||
);
|
|
||||||
(
|
|
||||||
serde_json::json!({
|
|
||||||
"error": err
|
|
||||||
})
|
|
||||||
.to_string(),
|
|
||||||
false,
|
|
||||||
Some(err),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match self
|
match self
|
||||||
.resume_with_tool_result(ResumeWithToolArgs {
|
.execute_tool_loop(
|
||||||
url: &url,
|
&url,
|
||||||
headers,
|
headers,
|
||||||
original_payload: &payload,
|
payload.clone(),
|
||||||
call_id: &call_id,
|
|
||||||
tool_name: &tool_name,
|
|
||||||
args_json_str: &args_json_str,
|
|
||||||
output_str: &output_payload,
|
|
||||||
original_body,
|
original_body,
|
||||||
})
|
mcp,
|
||||||
|
&loop_config,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(mut resumed_json) => {
|
Ok(loop_result) => loop_result,
|
||||||
// Inject MCP output items (mcp_list_tools and mcp_call)
|
|
||||||
let server_label = original_body
|
|
||||||
.tools
|
|
||||||
.iter()
|
|
||||||
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
|
|
||||||
.and_then(|t| t.server_label.as_deref())
|
|
||||||
.unwrap_or("mcp");
|
|
||||||
|
|
||||||
if let Err(inject_err) = Self::inject_mcp_output_items(
|
|
||||||
&mut resumed_json,
|
|
||||||
mcp,
|
|
||||||
McpOutputItemsArgs {
|
|
||||||
tool_name: &tool_name,
|
|
||||||
args_json: &args_json_str,
|
|
||||||
output: &output_payload,
|
|
||||||
server_label,
|
|
||||||
success: call_ok,
|
|
||||||
error: call_error.as_deref(),
|
|
||||||
},
|
|
||||||
) {
|
|
||||||
warn!(
|
|
||||||
"Failed to inject MCP output items: {}",
|
|
||||||
inject_err
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if !call_ok {
|
|
||||||
if let Some(obj) = resumed_json.as_object_mut() {
|
|
||||||
let metadata_value =
|
|
||||||
obj.entry("metadata").or_insert_with(|| {
|
|
||||||
Value::Object(serde_json::Map::new())
|
|
||||||
});
|
|
||||||
if let Some(metadata) =
|
|
||||||
metadata_value.as_object_mut()
|
|
||||||
{
|
|
||||||
if let Some(err_msg) = call_error.as_ref() {
|
|
||||||
metadata.insert(
|
|
||||||
"mcp_error".to_string(),
|
|
||||||
Value::String(err_msg.clone()),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
final_response_json = resumed_json;
|
|
||||||
}
|
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!("Failed to resume with tool result: {}", err);
|
warn!("Tool loop failed: {}", err);
|
||||||
let error_body = json!({
|
let error_body = json!({
|
||||||
"error": {
|
"error": {
|
||||||
"message": format!(
|
"message": format!("Tool loop failed: {}", err),
|
||||||
"Failed to resume with tool result: {}",
|
|
||||||
err
|
|
||||||
),
|
|
||||||
"type": "internal_error",
|
"type": "internal_error",
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.to_string();
|
.to_string();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
[("content-type", "application/json")],
|
[("content-type", "application/json")],
|
||||||
@@ -517,10 +485,14 @@ impl OpenAIRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
info!("No function call found in upstream response; skipping MCP");
|
// No function call detected, use response as-is
|
||||||
|
openai_response_json
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
|
openai_response_json
|
||||||
|
};
|
||||||
|
|
||||||
|
// Mask tools back to MCP format for client
|
||||||
Self::mask_tools_as_mcp(&mut final_response_json, original_body);
|
Self::mask_tools_as_mcp(&mut final_response_json, original_body);
|
||||||
if original_body.store {
|
if original_body.store {
|
||||||
if let Err(e) = self
|
if let Err(e) = self
|
||||||
@@ -1040,26 +1012,6 @@ impl OpenAIRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ResumeWithToolArgs<'a> {
|
|
||||||
url: &'a str,
|
|
||||||
headers: Option<&'a HeaderMap>,
|
|
||||||
original_payload: &'a Value,
|
|
||||||
call_id: &'a str,
|
|
||||||
tool_name: &'a str,
|
|
||||||
args_json_str: &'a str,
|
|
||||||
output_str: &'a str,
|
|
||||||
original_body: &'a ResponsesRequest,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct McpOutputItemsArgs<'a> {
|
|
||||||
tool_name: &'a str,
|
|
||||||
args_json: &'a str,
|
|
||||||
output: &'a str,
|
|
||||||
server_label: &'a str,
|
|
||||||
success: bool,
|
|
||||||
error: Option<&'a str>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAIRouter {
|
impl OpenAIRouter {
|
||||||
fn extract_function_call(resp: &Value) -> Option<(String, String, String)> {
|
fn extract_function_call(resp: &Value) -> Option<(String, String, String)> {
|
||||||
let output = resp.get("output")?.as_array()?;
|
let output = resp.get("output")?.as_array()?;
|
||||||
@@ -1150,6 +1102,375 @@ impl OpenAIRouter {
|
|||||||
Ok((server_name, output_str))
|
Ok((server_name, output_str))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build a resume payload with conversation history
|
||||||
|
fn build_resume_payload(
|
||||||
|
base_payload: &Value,
|
||||||
|
conversation_history: &[Value],
|
||||||
|
original_input: &ResponseInput,
|
||||||
|
tools_json: &Value,
|
||||||
|
) -> Result<Value, String> {
|
||||||
|
// Clone the base payload which already has cleaned fields
|
||||||
|
let mut payload = base_payload.clone();
|
||||||
|
|
||||||
|
let obj = payload
|
||||||
|
.as_object_mut()
|
||||||
|
.ok_or_else(|| "payload not an object".to_string())?;
|
||||||
|
|
||||||
|
// Build input array: start with original user input
|
||||||
|
let mut input_array = Vec::new();
|
||||||
|
|
||||||
|
// Add original user message
|
||||||
|
// For structured input, serialize the original input items
|
||||||
|
match original_input {
|
||||||
|
ResponseInput::Text(text) => {
|
||||||
|
let user_item = json!({
|
||||||
|
"type": "message",
|
||||||
|
"role": "user",
|
||||||
|
"content": [{ "type": "input_text", "text": text }]
|
||||||
|
});
|
||||||
|
input_array.push(user_item);
|
||||||
|
}
|
||||||
|
ResponseInput::Items(items) => {
|
||||||
|
// Items are already structured ResponseInputOutputItem, convert to JSON
|
||||||
|
if let Ok(items_value) = serde_json::to_value(items) {
|
||||||
|
if let Some(items_arr) = items_value.as_array() {
|
||||||
|
input_array.extend_from_slice(items_arr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add all conversation history (function calls and outputs)
|
||||||
|
input_array.extend_from_slice(conversation_history);
|
||||||
|
|
||||||
|
obj.insert("input".to_string(), Value::Array(input_array));
|
||||||
|
|
||||||
|
// Use the transformed tools (function tools, not MCP tools)
|
||||||
|
if let Some(tools_arr) = tools_json.as_array() {
|
||||||
|
if !tools_arr.is_empty() {
|
||||||
|
obj.insert("tools".to_string(), tools_json.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure non-streaming and no store to upstream
|
||||||
|
obj.insert("stream".to_string(), Value::Bool(false));
|
||||||
|
obj.insert("store".to_string(), Value::Bool(false));
|
||||||
|
|
||||||
|
// Note: SGLang-specific fields were already removed from base_payload
|
||||||
|
// before it was passed to execute_tool_loop (see route_responses lines 1935-1946)
|
||||||
|
|
||||||
|
Ok(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to build mcp_call items from executed tool calls in conversation history
|
||||||
|
fn build_executed_mcp_call_items(
|
||||||
|
conversation_history: &[Value],
|
||||||
|
server_label: &str,
|
||||||
|
) -> Vec<Value> {
|
||||||
|
let mut mcp_call_items = Vec::new();
|
||||||
|
|
||||||
|
for item in conversation_history {
|
||||||
|
if item.get("type").and_then(|t| t.as_str()) == Some("function_call") {
|
||||||
|
let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
|
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
|
let args = item
|
||||||
|
.get("arguments")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("{}");
|
||||||
|
|
||||||
|
// Find corresponding output
|
||||||
|
let output_item = conversation_history.iter().find(|o| {
|
||||||
|
o.get("type").and_then(|t| t.as_str()) == Some("function_call_output")
|
||||||
|
&& o.get("call_id").and_then(|c| c.as_str()) == Some(call_id)
|
||||||
|
});
|
||||||
|
|
||||||
|
let output_str = output_item
|
||||||
|
.and_then(|o| o.get("output").and_then(|v| v.as_str()))
|
||||||
|
.unwrap_or("{}");
|
||||||
|
|
||||||
|
// Check if output contains error by parsing JSON
|
||||||
|
let is_error = serde_json::from_str::<serde_json::Value>(output_str)
|
||||||
|
.map(|v| v.get("error").is_some())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
let mcp_call_item = Self::build_mcp_call_item(
|
||||||
|
tool_name,
|
||||||
|
args,
|
||||||
|
output_str,
|
||||||
|
server_label,
|
||||||
|
!is_error,
|
||||||
|
if is_error {
|
||||||
|
Some("Tool execution failed")
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
);
|
||||||
|
mcp_call_items.push(mcp_call_item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mcp_call_items
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build an incomplete response when limits are exceeded
|
||||||
|
fn build_incomplete_response(
|
||||||
|
mut response: Value,
|
||||||
|
state: ToolLoopState,
|
||||||
|
reason: &str,
|
||||||
|
active_mcp: &Arc<crate::mcp::McpClientManager>,
|
||||||
|
original_body: &ResponsesRequest,
|
||||||
|
) -> Result<Value, String> {
|
||||||
|
let obj = response
|
||||||
|
.as_object_mut()
|
||||||
|
.ok_or_else(|| "response not an object".to_string())?;
|
||||||
|
|
||||||
|
// Set status to completed (not failed - partial success)
|
||||||
|
obj.insert("status".to_string(), Value::String("completed".to_string()));
|
||||||
|
|
||||||
|
// Set incomplete_details
|
||||||
|
obj.insert(
|
||||||
|
"incomplete_details".to_string(),
|
||||||
|
json!({ "reason": reason }),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Convert any function_call in output to mcp_call format
|
||||||
|
if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
|
||||||
|
let server_label = original_body
|
||||||
|
.tools
|
||||||
|
.iter()
|
||||||
|
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
|
||||||
|
.and_then(|t| t.server_label.as_deref())
|
||||||
|
.unwrap_or("mcp");
|
||||||
|
|
||||||
|
// Find any function_call items and convert them to mcp_call (incomplete)
|
||||||
|
let mut mcp_call_items = Vec::new();
|
||||||
|
for item in output_array.iter() {
|
||||||
|
if item.get("type").and_then(|t| t.as_str()) == Some("function_tool_call") {
|
||||||
|
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
|
let args = item
|
||||||
|
.get("arguments")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("{}");
|
||||||
|
|
||||||
|
// Mark as incomplete - not executed
|
||||||
|
let mcp_call_item = Self::build_mcp_call_item(
|
||||||
|
tool_name,
|
||||||
|
args,
|
||||||
|
"", // No output - wasn't executed
|
||||||
|
server_label,
|
||||||
|
false, // Not successful
|
||||||
|
Some("Not executed - response stopped due to limit"),
|
||||||
|
);
|
||||||
|
mcp_call_items.push(mcp_call_item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add mcp_list_tools and executed mcp_call items at the beginning
|
||||||
|
if state.total_calls > 0 || !mcp_call_items.is_empty() {
|
||||||
|
let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label);
|
||||||
|
output_array.insert(0, list_tools_item);
|
||||||
|
|
||||||
|
// Add mcp_call items for executed calls using helper
|
||||||
|
let executed_items =
|
||||||
|
Self::build_executed_mcp_call_items(&state.conversation_history, server_label);
|
||||||
|
|
||||||
|
let mut insert_pos = 1;
|
||||||
|
for item in executed_items {
|
||||||
|
output_array.insert(insert_pos, item);
|
||||||
|
insert_pos += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add incomplete mcp_call items
|
||||||
|
for item in mcp_call_items {
|
||||||
|
output_array.insert(insert_pos, item);
|
||||||
|
insert_pos += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add warning to metadata
|
||||||
|
if let Some(metadata_val) = obj.get_mut("metadata") {
|
||||||
|
if let Some(metadata_obj) = metadata_val.as_object_mut() {
|
||||||
|
if let Some(mcp_val) = metadata_obj.get_mut("mcp") {
|
||||||
|
if let Some(mcp_obj) = mcp_val.as_object_mut() {
|
||||||
|
mcp_obj.insert(
|
||||||
|
"truncation_warning".to_string(),
|
||||||
|
Value::String(format!(
|
||||||
|
"Loop terminated at {} iterations, {} total calls (reason: {})",
|
||||||
|
state.iteration, state.total_calls, reason
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute the tool calling loop
|
||||||
|
async fn execute_tool_loop(
|
||||||
|
&self,
|
||||||
|
url: &str,
|
||||||
|
headers: Option<&HeaderMap>,
|
||||||
|
initial_payload: Value,
|
||||||
|
original_body: &ResponsesRequest,
|
||||||
|
active_mcp: &Arc<crate::mcp::McpClientManager>,
|
||||||
|
config: &McpLoopConfig,
|
||||||
|
) -> Result<Value, String> {
|
||||||
|
let mut state = ToolLoopState::new(original_body.input.clone());
|
||||||
|
|
||||||
|
// Get max_tool_calls from request (None means no user-specified limit)
|
||||||
|
let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize);
|
||||||
|
|
||||||
|
// Keep initial_payload as base template (already has fields cleaned)
|
||||||
|
let base_payload = initial_payload.clone();
|
||||||
|
let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([]));
|
||||||
|
let mut current_payload = initial_payload;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Starting tool loop: max_tool_calls={:?}, max_iterations={}",
|
||||||
|
max_tool_calls, config.max_iterations
|
||||||
|
);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// Make request to upstream
|
||||||
|
let request_builder = self.client.post(url).json(¤t_payload);
|
||||||
|
let request_builder = if let Some(headers) = headers {
|
||||||
|
apply_request_headers(headers, request_builder, true)
|
||||||
|
} else {
|
||||||
|
request_builder
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = request_builder
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("upstream request failed: {}", e))?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
return Err(format!("upstream error {}: {}", status, body));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut response_json = response
|
||||||
|
.json::<Value>()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("parse response: {}", e))?;
|
||||||
|
|
||||||
|
// Check for function call
|
||||||
|
if let Some((call_id, tool_name, args_json_str)) =
|
||||||
|
Self::extract_function_call(&response_json)
|
||||||
|
{
|
||||||
|
state.iteration += 1;
|
||||||
|
state.total_calls += 1;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Tool loop iteration {}: calling {} (call_id: {})",
|
||||||
|
state.iteration, tool_name, call_id
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check combined limit: use minimum of user's max_tool_calls (if set) and safety max_iterations
|
||||||
|
let effective_limit = match max_tool_calls {
|
||||||
|
Some(user_max) => user_max.min(config.max_iterations),
|
||||||
|
None => config.max_iterations,
|
||||||
|
};
|
||||||
|
|
||||||
|
if state.total_calls > effective_limit {
|
||||||
|
if let Some(user_max) = max_tool_calls {
|
||||||
|
if state.total_calls > user_max {
|
||||||
|
warn!("Reached user-specified max_tool_calls limit: {}", user_max);
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
"Reached safety max_iterations limit: {}",
|
||||||
|
config.max_iterations
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
"Reached safety max_iterations limit: {}",
|
||||||
|
config.max_iterations
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Self::build_incomplete_response(
|
||||||
|
response_json,
|
||||||
|
state,
|
||||||
|
"max_tool_calls",
|
||||||
|
active_mcp,
|
||||||
|
original_body,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute tool
|
||||||
|
let call_result =
|
||||||
|
Self::execute_mcp_call(active_mcp, &tool_name, &args_json_str).await;
|
||||||
|
|
||||||
|
let output_str = match call_result {
|
||||||
|
Ok((_, output)) => output,
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Tool execution failed: {}", err);
|
||||||
|
// Return error as output, let model decide how to proceed
|
||||||
|
json!({ "error": err }).to_string()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Record the call
|
||||||
|
state.record_call(call_id, tool_name, args_json_str, output_str);
|
||||||
|
|
||||||
|
// Build resume payload
|
||||||
|
current_payload = Self::build_resume_payload(
|
||||||
|
&base_payload,
|
||||||
|
&state.conversation_history,
|
||||||
|
&state.original_input,
|
||||||
|
&tools_json,
|
||||||
|
)?;
|
||||||
|
} else {
|
||||||
|
// No more tool calls, we're done
|
||||||
|
info!(
|
||||||
|
"Tool loop completed: {} iterations, {} total calls",
|
||||||
|
state.iteration, state.total_calls
|
||||||
|
);
|
||||||
|
|
||||||
|
// Inject MCP output items if we executed any tools
|
||||||
|
if state.total_calls > 0 {
|
||||||
|
let server_label = original_body
|
||||||
|
.tools
|
||||||
|
.iter()
|
||||||
|
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
|
||||||
|
.and_then(|t| t.server_label.as_deref())
|
||||||
|
.unwrap_or("mcp");
|
||||||
|
|
||||||
|
// Build mcp_list_tools item
|
||||||
|
let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label);
|
||||||
|
|
||||||
|
// Insert at beginning of output array
|
||||||
|
if let Some(output_array) = response_json
|
||||||
|
.get_mut("output")
|
||||||
|
.and_then(|v| v.as_array_mut())
|
||||||
|
{
|
||||||
|
output_array.insert(0, list_tools_item);
|
||||||
|
|
||||||
|
// Build mcp_call items using helper function
|
||||||
|
let mcp_call_items = Self::build_executed_mcp_call_items(
|
||||||
|
&state.conversation_history,
|
||||||
|
server_label,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Insert mcp_call items after mcp_list_tools using mutable position
|
||||||
|
let mut insert_pos = 1;
|
||||||
|
for item in mcp_call_items {
|
||||||
|
output_array.insert(insert_pos, item);
|
||||||
|
insert_pos += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(response_json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Generate a unique ID for MCP output items (similar to OpenAI format)
|
/// Generate a unique ID for MCP output items (similar to OpenAI format)
|
||||||
fn generate_mcp_id(prefix: &str) -> String {
|
fn generate_mcp_id(prefix: &str) -> String {
|
||||||
use rand::RngCore;
|
use rand::RngCore;
|
||||||
@@ -1213,113 +1534,6 @@ impl OpenAIRouter {
|
|||||||
"server_label": server_label
|
"server_label": server_label
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inject mcp_list_tools and mcp_call items into the response output array
|
|
||||||
fn inject_mcp_output_items(
|
|
||||||
response_json: &mut Value,
|
|
||||||
mcp: &Arc<crate::mcp::McpClientManager>,
|
|
||||||
args: McpOutputItemsArgs,
|
|
||||||
) -> Result<(), String> {
|
|
||||||
let output_array = response_json
|
|
||||||
.get_mut("output")
|
|
||||||
.and_then(|v| v.as_array_mut())
|
|
||||||
.ok_or("missing output array")?;
|
|
||||||
|
|
||||||
// Build MCP output items
|
|
||||||
let list_tools_item = Self::build_mcp_list_tools_item(mcp, args.server_label);
|
|
||||||
let call_item = Self::build_mcp_call_item(
|
|
||||||
args.tool_name,
|
|
||||||
args.args_json,
|
|
||||||
args.output,
|
|
||||||
args.server_label,
|
|
||||||
args.success,
|
|
||||||
args.error,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Find the index of the last message item to insert mcp_call before it
|
|
||||||
let call_insertion_index = output_array
|
|
||||||
.iter()
|
|
||||||
.rposition(|item| item.get("type").and_then(|v| v.as_str()) == Some("message"))
|
|
||||||
.unwrap_or(output_array.len());
|
|
||||||
|
|
||||||
// Insert items in-place for efficiency
|
|
||||||
output_array.insert(call_insertion_index, call_item);
|
|
||||||
output_array.insert(0, list_tools_item);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn resume_with_tool_result(&self, args: ResumeWithToolArgs<'_>) -> Result<Value, String> {
|
|
||||||
let mut payload2 = args.original_payload.clone();
|
|
||||||
let obj = payload2
|
|
||||||
.as_object_mut()
|
|
||||||
.ok_or_else(|| "payload not an object".to_string())?;
|
|
||||||
|
|
||||||
// Build function_call and tool result items per OpenAI Responses spec
|
|
||||||
let user_item = serde_json::json!({
|
|
||||||
"type": "message",
|
|
||||||
"role": "user",
|
|
||||||
"content": args.original_body.input.clone()
|
|
||||||
});
|
|
||||||
// temp system message since currently only support 1 turn of mcp function call
|
|
||||||
let system_item = serde_json::json!({
|
|
||||||
"type": "message",
|
|
||||||
"role": "system",
|
|
||||||
"content": "please resume with the following tool result, and answer user's question directly, don't trigger any more tool calls"
|
|
||||||
});
|
|
||||||
|
|
||||||
let func_item = serde_json::json!({
|
|
||||||
"type": "function_call",
|
|
||||||
"call_id": args.call_id,
|
|
||||||
"name": args.tool_name,
|
|
||||||
"arguments": args.args_json_str
|
|
||||||
});
|
|
||||||
// Build tool result item as function_call_output per OpenAI Responses spec
|
|
||||||
let tool_item = serde_json::json!({
|
|
||||||
"type": "function_call_output",
|
|
||||||
"call_id": args.call_id,
|
|
||||||
"output": args.output_str
|
|
||||||
});
|
|
||||||
|
|
||||||
obj.insert(
|
|
||||||
"input".to_string(),
|
|
||||||
Value::Array(vec![user_item, system_item, func_item, tool_item]),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Ensure non-streaming and no store to upstream
|
|
||||||
obj.insert("stream".to_string(), Value::Bool(false));
|
|
||||||
obj.insert("store".to_string(), Value::Bool(false));
|
|
||||||
let mut req = self.client.post(args.url).json(&payload2);
|
|
||||||
if let Some(headers) = args.headers {
|
|
||||||
req = apply_request_headers(headers, req, true);
|
|
||||||
}
|
|
||||||
let resp = req
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| format!("resume request failed: {}", e))?;
|
|
||||||
if !resp.status().is_success() {
|
|
||||||
let status = resp.status();
|
|
||||||
let body = resp.text().await.unwrap_or_default();
|
|
||||||
return Err(format!("resume upstream error {}: {}", status, body));
|
|
||||||
}
|
|
||||||
let mut v = resp
|
|
||||||
.json::<Value>()
|
|
||||||
.await
|
|
||||||
.map_err(|e| format!("parse resume response: {}", e))?;
|
|
||||||
|
|
||||||
if let Some(instr) = &args.original_body.instructions {
|
|
||||||
if let Some(obj) = v.as_object_mut() {
|
|
||||||
obj.entry("instructions")
|
|
||||||
.or_insert(Value::String(instr.clone()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// After resume, mask tools as MCP if request used MCP
|
|
||||||
Self::mask_tools_as_mcp(&mut v, args.original_body);
|
|
||||||
if let Some(obj) = v.as_object_mut() {
|
|
||||||
obj.insert("store".to_string(), Value::Bool(args.original_body.store));
|
|
||||||
}
|
|
||||||
Ok(v)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|||||||
@@ -252,6 +252,7 @@ fn test_responses_request_creation() {
|
|||||||
previous_response_id: None,
|
previous_response_id: None,
|
||||||
reasoning: Some(ResponseReasoningParam {
|
reasoning: Some(ResponseReasoningParam {
|
||||||
effort: Some(ReasoningEffort::Medium),
|
effort: Some(ReasoningEffort::Medium),
|
||||||
|
summary: None,
|
||||||
}),
|
}),
|
||||||
service_tier: ServiceTier::Auto,
|
service_tier: ServiceTier::Auto,
|
||||||
store: true,
|
store: true,
|
||||||
@@ -380,6 +381,7 @@ fn test_usage_conversion() {
|
|||||||
fn test_reasoning_param_default() {
|
fn test_reasoning_param_default() {
|
||||||
let param = ResponseReasoningParam {
|
let param = ResponseReasoningParam {
|
||||||
effort: Some(ReasoningEffort::Medium),
|
effort: Some(ReasoningEffort::Medium),
|
||||||
|
summary: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let json = serde_json::to_string(¶m).unwrap();
|
let json = serde_json::to_string(¶m).unwrap();
|
||||||
@@ -403,6 +405,7 @@ fn test_json_serialization() {
|
|||||||
previous_response_id: None,
|
previous_response_id: None,
|
||||||
reasoning: Some(ResponseReasoningParam {
|
reasoning: Some(ResponseReasoningParam {
|
||||||
effort: Some(ReasoningEffort::High),
|
effort: Some(ReasoningEffort::High),
|
||||||
|
summary: None,
|
||||||
}),
|
}),
|
||||||
service_tier: ServiceTier::Priority,
|
service_tier: ServiceTier::Priority,
|
||||||
store: false,
|
store: false,
|
||||||
@@ -437,3 +440,328 @@ fn test_json_serialization() {
|
|||||||
assert!(parsed.stream);
|
assert!(parsed.stream);
|
||||||
assert_eq!(parsed.tools.len(), 1);
|
assert_eq!(parsed.tools.len(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_multi_turn_loop_with_mcp() {
|
||||||
|
// This test verifies the multi-turn loop functionality:
|
||||||
|
// 1. Initial request with MCP tools
|
||||||
|
// 2. Mock worker returns function_call
|
||||||
|
// 3. Router executes MCP tool and resumes
|
||||||
|
// 4. Mock worker returns final answer
|
||||||
|
// 5. Verify the complete flow worked
|
||||||
|
|
||||||
|
// Start mock MCP server
|
||||||
|
let mut mcp = MockMCPServer::start().await.expect("start mcp");
|
||||||
|
|
||||||
|
// Write a temp MCP config file
|
||||||
|
let mcp_yaml = format!(
|
||||||
|
"servers:\n - name: mock\n protocol: streamable\n url: {}\n",
|
||||||
|
mcp.url()
|
||||||
|
);
|
||||||
|
let dir = tempfile::tempdir().expect("tmpdir");
|
||||||
|
let cfg_path = dir.path().join("mcp.yaml");
|
||||||
|
std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
|
||||||
|
std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap());
|
||||||
|
|
||||||
|
// Start mock OpenAI worker
|
||||||
|
let mut worker = MockWorker::new(MockWorkerConfig {
|
||||||
|
port: 0,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
});
|
||||||
|
let worker_url = worker.start().await.expect("start worker");
|
||||||
|
|
||||||
|
// Build router config
|
||||||
|
let router_cfg = RouterConfig {
|
||||||
|
mode: RoutingMode::OpenAI {
|
||||||
|
worker_urls: vec![worker_url],
|
||||||
|
},
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
policy: PolicyConfig::Random,
|
||||||
|
host: "127.0.0.1".to_string(),
|
||||||
|
port: 0,
|
||||||
|
max_payload_size: 8 * 1024 * 1024,
|
||||||
|
request_timeout_secs: 60,
|
||||||
|
worker_startup_timeout_secs: 5,
|
||||||
|
worker_startup_check_interval_secs: 1,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
|
discovery: None,
|
||||||
|
metrics: None,
|
||||||
|
log_dir: None,
|
||||||
|
log_level: Some("info".to_string()),
|
||||||
|
request_id_headers: None,
|
||||||
|
max_concurrent_requests: 32,
|
||||||
|
queue_size: 0,
|
||||||
|
queue_timeout_secs: 5,
|
||||||
|
rate_limit_tokens_per_second: None,
|
||||||
|
cors_allowed_origins: vec![],
|
||||||
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
|
disable_retries: false,
|
||||||
|
disable_circuit_breaker: false,
|
||||||
|
health_check: HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
|
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||||
|
oracle: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||||
|
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||||
|
.await
|
||||||
|
.expect("router");
|
||||||
|
|
||||||
|
// Build request with MCP tools
|
||||||
|
let req = ResponsesRequest {
|
||||||
|
background: false,
|
||||||
|
include: None,
|
||||||
|
input: ResponseInput::Text("search for SGLang".to_string()),
|
||||||
|
instructions: Some("Be helpful".to_string()),
|
||||||
|
max_output_tokens: Some(128),
|
||||||
|
max_tool_calls: None, // No limit - test unlimited
|
||||||
|
metadata: None,
|
||||||
|
model: Some("mock-model".to_string()),
|
||||||
|
parallel_tool_calls: true,
|
||||||
|
previous_response_id: None,
|
||||||
|
reasoning: None,
|
||||||
|
service_tier: ServiceTier::Auto,
|
||||||
|
store: true,
|
||||||
|
stream: false,
|
||||||
|
temperature: Some(0.7),
|
||||||
|
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||||
|
tools: vec![ResponseTool {
|
||||||
|
r#type: ResponseToolType::Mcp,
|
||||||
|
server_url: Some(mcp.url()),
|
||||||
|
server_label: Some("mock".to_string()),
|
||||||
|
server_description: Some("Mock MCP server for testing".to_string()),
|
||||||
|
require_approval: Some("never".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
}],
|
||||||
|
top_logprobs: 0,
|
||||||
|
top_p: Some(1.0),
|
||||||
|
truncation: Truncation::Disabled,
|
||||||
|
user: None,
|
||||||
|
request_id: "resp_multi_turn_test".to_string(),
|
||||||
|
priority: 0,
|
||||||
|
frequency_penalty: 0.0,
|
||||||
|
presence_penalty: 0.0,
|
||||||
|
stop: None,
|
||||||
|
top_k: 50,
|
||||||
|
min_p: 0.0,
|
||||||
|
repetition_penalty: 1.0,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Execute the request (this should trigger the multi-turn loop)
|
||||||
|
let response = router.route_responses(None, &req, None).await;
|
||||||
|
|
||||||
|
// Check status
|
||||||
|
assert_eq!(
|
||||||
|
response.status(),
|
||||||
|
axum::http::StatusCode::OK,
|
||||||
|
"Request should succeed"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Read the response body
|
||||||
|
use axum::body::to_bytes;
|
||||||
|
let response_body = response.into_body();
|
||||||
|
let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
|
||||||
|
let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"Multi-turn response: {}",
|
||||||
|
serde_json::to_string_pretty(&response_json).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify the response structure
|
||||||
|
assert_eq!(response_json["object"], "response");
|
||||||
|
assert_eq!(response_json["status"], "completed");
|
||||||
|
// Note: mock worker generates its own ID, so we just verify it exists
|
||||||
|
assert!(
|
||||||
|
response_json["id"].is_string(),
|
||||||
|
"Response should have an id"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check that output contains final message
|
||||||
|
let output = response_json["output"]
|
||||||
|
.as_array()
|
||||||
|
.expect("output should be array");
|
||||||
|
assert!(!output.is_empty(), "output should not be empty");
|
||||||
|
|
||||||
|
// Find the final message with text
|
||||||
|
let has_final_text = output.iter().any(|item| {
|
||||||
|
item.get("type")
|
||||||
|
.and_then(|t| t.as_str())
|
||||||
|
.map(|t| t == "message")
|
||||||
|
.unwrap_or(false)
|
||||||
|
&& item
|
||||||
|
.get("content")
|
||||||
|
.and_then(|c| c.as_array())
|
||||||
|
.map(|arr| {
|
||||||
|
arr.iter().any(|part| {
|
||||||
|
part.get("type")
|
||||||
|
.and_then(|t| t.as_str())
|
||||||
|
.map(|t| t == "output_text")
|
||||||
|
.unwrap_or(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.unwrap_or(false)
|
||||||
|
});
|
||||||
|
|
||||||
|
assert!(has_final_text, "Should have final text output");
|
||||||
|
|
||||||
|
// Verify tools are masked back to MCP format
|
||||||
|
let tools = response_json["tools"]
|
||||||
|
.as_array()
|
||||||
|
.expect("tools should be array");
|
||||||
|
assert_eq!(tools.len(), 1);
|
||||||
|
assert_eq!(tools[0]["type"], "mcp");
|
||||||
|
assert_eq!(tools[0]["server_label"], "mock");
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
std::env::remove_var("SGLANG_MCP_CONFIG");
|
||||||
|
worker.stop().await;
|
||||||
|
mcp.stop().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_max_tool_calls_limit() {
|
||||||
|
// This test verifies that max_tool_calls is respected
|
||||||
|
// Note: The mock worker returns a final answer after one tool call,
|
||||||
|
// so with max_tool_calls=1, it completes normally (doesn't exceed the limit)
|
||||||
|
|
||||||
|
let mut mcp = MockMCPServer::start().await.expect("start mcp");
|
||||||
|
let mcp_yaml = format!(
|
||||||
|
"servers:\n - name: mock\n protocol: streamable\n url: {}\n",
|
||||||
|
mcp.url()
|
||||||
|
);
|
||||||
|
let dir = tempfile::tempdir().expect("tmpdir");
|
||||||
|
let cfg_path = dir.path().join("mcp.yaml");
|
||||||
|
std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
|
||||||
|
std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap());
|
||||||
|
|
||||||
|
let mut worker = MockWorker::new(MockWorkerConfig {
|
||||||
|
port: 0,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
});
|
||||||
|
let worker_url = worker.start().await.expect("start worker");
|
||||||
|
|
||||||
|
let router_cfg = RouterConfig {
|
||||||
|
mode: RoutingMode::OpenAI {
|
||||||
|
worker_urls: vec![worker_url],
|
||||||
|
},
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
policy: PolicyConfig::Random,
|
||||||
|
host: "127.0.0.1".to_string(),
|
||||||
|
port: 0,
|
||||||
|
max_payload_size: 8 * 1024 * 1024,
|
||||||
|
request_timeout_secs: 60,
|
||||||
|
worker_startup_timeout_secs: 5,
|
||||||
|
worker_startup_check_interval_secs: 1,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
|
discovery: None,
|
||||||
|
metrics: None,
|
||||||
|
log_dir: None,
|
||||||
|
log_level: Some("info".to_string()),
|
||||||
|
request_id_headers: None,
|
||||||
|
max_concurrent_requests: 32,
|
||||||
|
queue_size: 0,
|
||||||
|
queue_timeout_secs: 5,
|
||||||
|
rate_limit_tokens_per_second: None,
|
||||||
|
cors_allowed_origins: vec![],
|
||||||
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
|
disable_retries: false,
|
||||||
|
disable_circuit_breaker: false,
|
||||||
|
health_check: HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
|
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||||
|
oracle: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||||
|
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||||
|
.await
|
||||||
|
.expect("router");
|
||||||
|
|
||||||
|
let req = ResponsesRequest {
|
||||||
|
background: false,
|
||||||
|
include: None,
|
||||||
|
input: ResponseInput::Text("test max calls".to_string()),
|
||||||
|
instructions: None,
|
||||||
|
max_output_tokens: Some(128),
|
||||||
|
max_tool_calls: Some(1), // Limit to 1 call
|
||||||
|
metadata: None,
|
||||||
|
model: Some("mock-model".to_string()),
|
||||||
|
parallel_tool_calls: true,
|
||||||
|
previous_response_id: None,
|
||||||
|
reasoning: None,
|
||||||
|
service_tier: ServiceTier::Auto,
|
||||||
|
store: false,
|
||||||
|
stream: false,
|
||||||
|
temperature: Some(0.7),
|
||||||
|
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||||
|
tools: vec![ResponseTool {
|
||||||
|
r#type: ResponseToolType::Mcp,
|
||||||
|
server_url: Some(mcp.url()),
|
||||||
|
server_label: Some("mock".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
}],
|
||||||
|
top_logprobs: 0,
|
||||||
|
top_p: Some(1.0),
|
||||||
|
truncation: Truncation::Disabled,
|
||||||
|
user: None,
|
||||||
|
request_id: "resp_max_calls_test".to_string(),
|
||||||
|
priority: 0,
|
||||||
|
frequency_penalty: 0.0,
|
||||||
|
presence_penalty: 0.0,
|
||||||
|
stop: None,
|
||||||
|
top_k: 50,
|
||||||
|
min_p: 0.0,
|
||||||
|
repetition_penalty: 1.0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = router.route_responses(None, &req, None).await;
|
||||||
|
assert_eq!(response.status(), axum::http::StatusCode::OK);
|
||||||
|
|
||||||
|
use axum::body::to_bytes;
|
||||||
|
let response_body = response.into_body();
|
||||||
|
let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
|
||||||
|
let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"Max calls response: {}",
|
||||||
|
serde_json::to_string_pretty(&response_json).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
// With max_tool_calls=1, the mock returns a final answer after 1 call
|
||||||
|
// So it completes normally without exceeding the limit
|
||||||
|
assert_eq!(response_json["status"], "completed");
|
||||||
|
|
||||||
|
// Verify the basic response structure
|
||||||
|
assert!(response_json["id"].is_string());
|
||||||
|
assert_eq!(response_json["object"], "response");
|
||||||
|
|
||||||
|
// The response should have tools masked back to MCP format
|
||||||
|
let tools = response_json["tools"]
|
||||||
|
.as_array()
|
||||||
|
.expect("tools should be array");
|
||||||
|
assert_eq!(tools.len(), 1);
|
||||||
|
assert_eq!(tools[0]["type"], "mcp");
|
||||||
|
|
||||||
|
// Note: To test actual limit exceeding, we would need a mock that keeps
|
||||||
|
// calling tools indefinitely, which would hit max_iterations (safety limit)
|
||||||
|
|
||||||
|
std::env::remove_var("SGLANG_MCP_CONFIG");
|
||||||
|
worker.stop().await;
|
||||||
|
mcp.stop().await;
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user