[router] Add multi-turn tool calling loop support for MCP integration (#11143)
This commit is contained in:
@@ -252,6 +252,7 @@ fn test_responses_request_creation() {
|
||||
previous_response_id: None,
|
||||
reasoning: Some(ResponseReasoningParam {
|
||||
effort: Some(ReasoningEffort::Medium),
|
||||
summary: None,
|
||||
}),
|
||||
service_tier: ServiceTier::Auto,
|
||||
store: true,
|
||||
@@ -380,6 +381,7 @@ fn test_usage_conversion() {
|
||||
fn test_reasoning_param_default() {
|
||||
let param = ResponseReasoningParam {
|
||||
effort: Some(ReasoningEffort::Medium),
|
||||
summary: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(¶m).unwrap();
|
||||
@@ -403,6 +405,7 @@ fn test_json_serialization() {
|
||||
previous_response_id: None,
|
||||
reasoning: Some(ResponseReasoningParam {
|
||||
effort: Some(ReasoningEffort::High),
|
||||
summary: None,
|
||||
}),
|
||||
service_tier: ServiceTier::Priority,
|
||||
store: false,
|
||||
@@ -437,3 +440,328 @@ fn test_json_serialization() {
|
||||
assert!(parsed.stream);
|
||||
assert_eq!(parsed.tools.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_turn_loop_with_mcp() {
|
||||
// This test verifies the multi-turn loop functionality:
|
||||
// 1. Initial request with MCP tools
|
||||
// 2. Mock worker returns function_call
|
||||
// 3. Router executes MCP tool and resumes
|
||||
// 4. Mock worker returns final answer
|
||||
// 5. Verify the complete flow worked
|
||||
|
||||
// Start mock MCP server
|
||||
let mut mcp = MockMCPServer::start().await.expect("start mcp");
|
||||
|
||||
// Write a temp MCP config file
|
||||
let mcp_yaml = format!(
|
||||
"servers:\n - name: mock\n protocol: streamable\n url: {}\n",
|
||||
mcp.url()
|
||||
);
|
||||
let dir = tempfile::tempdir().expect("tmpdir");
|
||||
let cfg_path = dir.path().join("mcp.yaml");
|
||||
std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
|
||||
std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap());
|
||||
|
||||
// Start mock OpenAI worker
|
||||
let mut worker = MockWorker::new(MockWorkerConfig {
|
||||
port: 0,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
});
|
||||
let worker_url = worker.start().await.expect("start worker");
|
||||
|
||||
// Build router config
|
||||
let router_cfg = RouterConfig {
|
||||
mode: RoutingMode::OpenAI {
|
||||
worker_urls: vec![worker_url],
|
||||
},
|
||||
connection_mode: ConnectionMode::Http,
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 0,
|
||||
max_payload_size: 8 * 1024 * 1024,
|
||||
request_timeout_secs: 60,
|
||||
worker_startup_timeout_secs: 5,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: Some("info".to_string()),
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 32,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 5,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
|
||||
// Build request with MCP tools
|
||||
let req = ResponsesRequest {
|
||||
background: false,
|
||||
include: None,
|
||||
input: ResponseInput::Text("search for SGLang".to_string()),
|
||||
instructions: Some("Be helpful".to_string()),
|
||||
max_output_tokens: Some(128),
|
||||
max_tool_calls: None, // No limit - test unlimited
|
||||
metadata: None,
|
||||
model: Some("mock-model".to_string()),
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
service_tier: ServiceTier::Auto,
|
||||
store: true,
|
||||
stream: false,
|
||||
temperature: Some(0.7),
|
||||
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::Mcp,
|
||||
server_url: Some(mcp.url()),
|
||||
server_label: Some("mock".to_string()),
|
||||
server_description: Some("Mock MCP server for testing".to_string()),
|
||||
require_approval: Some("never".to_string()),
|
||||
..Default::default()
|
||||
}],
|
||||
top_logprobs: 0,
|
||||
top_p: Some(1.0),
|
||||
truncation: Truncation::Disabled,
|
||||
user: None,
|
||||
request_id: "resp_multi_turn_test".to_string(),
|
||||
priority: 0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop: None,
|
||||
top_k: 50,
|
||||
min_p: 0.0,
|
||||
repetition_penalty: 1.0,
|
||||
};
|
||||
|
||||
// Execute the request (this should trigger the multi-turn loop)
|
||||
let response = router.route_responses(None, &req, None).await;
|
||||
|
||||
// Check status
|
||||
assert_eq!(
|
||||
response.status(),
|
||||
axum::http::StatusCode::OK,
|
||||
"Request should succeed"
|
||||
);
|
||||
|
||||
// Read the response body
|
||||
use axum::body::to_bytes;
|
||||
let response_body = response.into_body();
|
||||
let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
|
||||
let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
|
||||
println!(
|
||||
"Multi-turn response: {}",
|
||||
serde_json::to_string_pretty(&response_json).unwrap()
|
||||
);
|
||||
|
||||
// Verify the response structure
|
||||
assert_eq!(response_json["object"], "response");
|
||||
assert_eq!(response_json["status"], "completed");
|
||||
// Note: mock worker generates its own ID, so we just verify it exists
|
||||
assert!(
|
||||
response_json["id"].is_string(),
|
||||
"Response should have an id"
|
||||
);
|
||||
|
||||
// Check that output contains final message
|
||||
let output = response_json["output"]
|
||||
.as_array()
|
||||
.expect("output should be array");
|
||||
assert!(!output.is_empty(), "output should not be empty");
|
||||
|
||||
// Find the final message with text
|
||||
let has_final_text = output.iter().any(|item| {
|
||||
item.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|t| t == "message")
|
||||
.unwrap_or(false)
|
||||
&& item
|
||||
.get("content")
|
||||
.and_then(|c| c.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter().any(|part| {
|
||||
part.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|t| t == "output_text")
|
||||
.unwrap_or(false)
|
||||
})
|
||||
})
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
assert!(has_final_text, "Should have final text output");
|
||||
|
||||
// Verify tools are masked back to MCP format
|
||||
let tools = response_json["tools"]
|
||||
.as_array()
|
||||
.expect("tools should be array");
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0]["type"], "mcp");
|
||||
assert_eq!(tools[0]["server_label"], "mock");
|
||||
|
||||
// Clean up
|
||||
std::env::remove_var("SGLANG_MCP_CONFIG");
|
||||
worker.stop().await;
|
||||
mcp.stop().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_max_tool_calls_limit() {
|
||||
// This test verifies that max_tool_calls is respected
|
||||
// Note: The mock worker returns a final answer after one tool call,
|
||||
// so with max_tool_calls=1, it completes normally (doesn't exceed the limit)
|
||||
|
||||
let mut mcp = MockMCPServer::start().await.expect("start mcp");
|
||||
let mcp_yaml = format!(
|
||||
"servers:\n - name: mock\n protocol: streamable\n url: {}\n",
|
||||
mcp.url()
|
||||
);
|
||||
let dir = tempfile::tempdir().expect("tmpdir");
|
||||
let cfg_path = dir.path().join("mcp.yaml");
|
||||
std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
|
||||
std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap());
|
||||
|
||||
let mut worker = MockWorker::new(MockWorkerConfig {
|
||||
port: 0,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
});
|
||||
let worker_url = worker.start().await.expect("start worker");
|
||||
|
||||
let router_cfg = RouterConfig {
|
||||
mode: RoutingMode::OpenAI {
|
||||
worker_urls: vec![worker_url],
|
||||
},
|
||||
connection_mode: ConnectionMode::Http,
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 0,
|
||||
max_payload_size: 8 * 1024 * 1024,
|
||||
request_timeout_secs: 60,
|
||||
worker_startup_timeout_secs: 5,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: Some("info".to_string()),
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 32,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 5,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
|
||||
let req = ResponsesRequest {
|
||||
background: false,
|
||||
include: None,
|
||||
input: ResponseInput::Text("test max calls".to_string()),
|
||||
instructions: None,
|
||||
max_output_tokens: Some(128),
|
||||
max_tool_calls: Some(1), // Limit to 1 call
|
||||
metadata: None,
|
||||
model: Some("mock-model".to_string()),
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
service_tier: ServiceTier::Auto,
|
||||
store: false,
|
||||
stream: false,
|
||||
temperature: Some(0.7),
|
||||
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::Mcp,
|
||||
server_url: Some(mcp.url()),
|
||||
server_label: Some("mock".to_string()),
|
||||
..Default::default()
|
||||
}],
|
||||
top_logprobs: 0,
|
||||
top_p: Some(1.0),
|
||||
truncation: Truncation::Disabled,
|
||||
user: None,
|
||||
request_id: "resp_max_calls_test".to_string(),
|
||||
priority: 0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop: None,
|
||||
top_k: 50,
|
||||
min_p: 0.0,
|
||||
repetition_penalty: 1.0,
|
||||
};
|
||||
|
||||
let response = router.route_responses(None, &req, None).await;
|
||||
assert_eq!(response.status(), axum::http::StatusCode::OK);
|
||||
|
||||
use axum::body::to_bytes;
|
||||
let response_body = response.into_body();
|
||||
let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
|
||||
let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
|
||||
println!(
|
||||
"Max calls response: {}",
|
||||
serde_json::to_string_pretty(&response_json).unwrap()
|
||||
);
|
||||
|
||||
// With max_tool_calls=1, the mock returns a final answer after 1 call
|
||||
// So it completes normally without exceeding the limit
|
||||
assert_eq!(response_json["status"], "completed");
|
||||
|
||||
// Verify the basic response structure
|
||||
assert!(response_json["id"].is_string());
|
||||
assert_eq!(response_json["object"], "response");
|
||||
|
||||
// The response should have tools masked back to MCP format
|
||||
let tools = response_json["tools"]
|
||||
.as_array()
|
||||
.expect("tools should be array");
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0]["type"], "mcp");
|
||||
|
||||
// Note: To test actual limit exceeding, we would need a mock that keeps
|
||||
// calling tools indefinitely, which would hit max_iterations (safety limit)
|
||||
|
||||
std::env::remove_var("SGLANG_MCP_CONFIG");
|
||||
worker.stop().await;
|
||||
mcp.stop().await;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user