[router] basic mcp support for openai router response api (#10978)
This commit is contained in:
@@ -644,27 +644,96 @@ async fn responses_handler(
|
||||
}))
|
||||
.into_response()
|
||||
} else {
|
||||
Json(json!({
|
||||
"id": format!("resp-{}", Uuid::new_v4()),
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "This is a mock responses output."
|
||||
}]
|
||||
}],
|
||||
"status": "completed",
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15
|
||||
}
|
||||
}))
|
||||
.into_response()
|
||||
// If tools are provided and this is the first call (no previous_response_id),
|
||||
// emit a single function_tool_call to trigger the router's MCP flow.
|
||||
let has_tools = payload
|
||||
.get("tools")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter().any(|tool| {
|
||||
tool.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|t| t == "function")
|
||||
.unwrap_or(false)
|
||||
})
|
||||
})
|
||||
.unwrap_or(false);
|
||||
let has_function_output = payload
|
||||
.get("input")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|items| {
|
||||
items.iter().any(|item| {
|
||||
item.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|t| t == "function_call_output")
|
||||
.unwrap_or(false)
|
||||
})
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
if has_tools && !has_function_output {
|
||||
let rid = format!("resp-{}", Uuid::new_v4());
|
||||
Json(json!({
|
||||
"id": rid,
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [{
|
||||
"type": "function_tool_call",
|
||||
"id": "call_1",
|
||||
"name": "brave_web_search",
|
||||
"arguments": "{\"query\":\"SGLang router MCP integration\"}",
|
||||
"status": "in_progress"
|
||||
}],
|
||||
"status": "in_progress",
|
||||
"usage": null
|
||||
}))
|
||||
.into_response()
|
||||
} else if has_tools && has_function_output {
|
||||
Json(json!({
|
||||
"id": format!("resp-{}", Uuid::new_v4()),
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "Tool result consumed; here is the final answer."
|
||||
}]
|
||||
}],
|
||||
"status": "completed",
|
||||
"usage": {
|
||||
"input_tokens": 12,
|
||||
"output_tokens": 7,
|
||||
"total_tokens": 19
|
||||
}
|
||||
}))
|
||||
.into_response()
|
||||
} else {
|
||||
Json(json!({
|
||||
"id": format!("resp-{}", Uuid::new_v4()),
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "This is a mock responses output."
|
||||
}]
|
||||
}],
|
||||
"status": "completed",
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15
|
||||
}
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,186 @@ use sglang_router_rs::protocols::spec::{
|
||||
ToolChoiceValue, Truncation, UsageInfo,
|
||||
};
|
||||
|
||||
mod common;
|
||||
use common::mock_mcp_server::MockMCPServer;
|
||||
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConnectionMode, HealthCheckConfig, PolicyConfig, RetryConfig,
|
||||
RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::routers::RouterFactory;
|
||||
use sglang_router_rs::server::AppContext;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
||||
// Start mock MCP server
|
||||
let mut mcp = MockMCPServer::start().await.expect("start mcp");
|
||||
|
||||
// Write a temp MCP config file
|
||||
let mcp_yaml = format!(
|
||||
"servers:\n - name: mock\n protocol: streamable\n url: {}\n",
|
||||
mcp.url()
|
||||
);
|
||||
let dir = tempfile::tempdir().expect("tmpdir");
|
||||
let cfg_path = dir.path().join("mcp.yaml");
|
||||
std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
|
||||
|
||||
// Start mock OpenAI worker
|
||||
let mut worker = MockWorker::new(MockWorkerConfig {
|
||||
port: 0,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
});
|
||||
let worker_url = worker.start().await.expect("start worker");
|
||||
|
||||
// Build router config (HTTP OpenAI mode)
|
||||
let router_cfg = RouterConfig {
|
||||
mode: RoutingMode::OpenAI {
|
||||
worker_urls: vec![worker_url],
|
||||
},
|
||||
connection_mode: ConnectionMode::Http,
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 0,
|
||||
max_payload_size: 8 * 1024 * 1024,
|
||||
request_timeout_secs: 60,
|
||||
worker_startup_timeout_secs: 5,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: Some("warn".to_string()),
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 32,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 5,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
// Create router and context
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
|
||||
// Build a simple ResponsesRequest that will trigger the tool call
|
||||
let req = ResponsesRequest {
|
||||
background: false,
|
||||
include: None,
|
||||
input: ResponseInput::Text("search something".to_string()),
|
||||
instructions: Some("Be brief".to_string()),
|
||||
max_output_tokens: Some(64),
|
||||
max_tool_calls: None,
|
||||
metadata: None,
|
||||
model: Some("mock-model".to_string()),
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
service_tier: sglang_router_rs::protocols::spec::ServiceTier::Auto,
|
||||
store: true,
|
||||
stream: false,
|
||||
temperature: Some(0.2),
|
||||
tool_choice: sglang_router_rs::protocols::spec::ToolChoice::default(),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::Mcp,
|
||||
server_url: Some(mcp.url()),
|
||||
authorization: None,
|
||||
server_label: Some("mock".to_string()),
|
||||
server_description: None,
|
||||
require_approval: None,
|
||||
allowed_tools: None,
|
||||
}],
|
||||
top_logprobs: 0,
|
||||
top_p: None,
|
||||
truncation: sglang_router_rs::protocols::spec::Truncation::Disabled,
|
||||
user: None,
|
||||
request_id: "resp_test_mcp_e2e".to_string(),
|
||||
priority: 0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop: None,
|
||||
top_k: -1,
|
||||
min_p: 0.0,
|
||||
repetition_penalty: 1.0,
|
||||
};
|
||||
|
||||
let resp = router
|
||||
.route_responses(None, &req, req.model.as_deref())
|
||||
.await;
|
||||
|
||||
assert_eq!(resp.status(), axum::http::StatusCode::OK);
|
||||
|
||||
let body_bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Failed to read response body");
|
||||
let body_json: serde_json::Value =
|
||||
serde_json::from_slice(&body_bytes).expect("Failed to parse response JSON");
|
||||
|
||||
let output = body_json
|
||||
.get("output")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("response output missing");
|
||||
assert!(!output.is_empty(), "expected at least one output item");
|
||||
|
||||
let final_text = output
|
||||
.iter()
|
||||
.rev()
|
||||
.filter_map(|entry| entry.get("content"))
|
||||
.filter_map(|content| content.as_array())
|
||||
.flat_map(|parts| parts.iter())
|
||||
.filter_map(|part| part.get("text"))
|
||||
.filter_map(|v| v.as_str())
|
||||
.next();
|
||||
|
||||
if let Some(text) = final_text {
|
||||
assert_eq!(text, "Tool result consumed; here is the final answer.");
|
||||
} else {
|
||||
let call_entry = output.iter().find(|entry| {
|
||||
entry.get("type") == Some(&serde_json::Value::String("function_tool_call".into()))
|
||||
});
|
||||
assert!(call_entry.is_some(), "missing function tool call entry");
|
||||
if let Some(entry) = call_entry {
|
||||
assert_eq!(
|
||||
entry.get("status").and_then(|v| v.as_str()),
|
||||
Some("in_progress"),
|
||||
"function call should be in progress when no content is returned"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let tools = body_json
|
||||
.get("tools")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("tools array missing");
|
||||
assert_eq!(tools.len(), 1);
|
||||
let tool = tools.first().unwrap();
|
||||
assert_eq!(tool.get("type").and_then(|v| v.as_str()), Some("mcp"));
|
||||
assert_eq!(
|
||||
tool.get("server_label").and_then(|v| v.as_str()),
|
||||
Some("mock")
|
||||
);
|
||||
|
||||
// Cleanup
|
||||
worker.stop().await;
|
||||
mcp.stop().await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_responses_request_creation() {
|
||||
let request = ResponsesRequest {
|
||||
@@ -29,6 +209,7 @@ fn test_responses_request_creation() {
|
||||
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::WebSearchPreview,
|
||||
..Default::default()
|
||||
}],
|
||||
top_logprobs: 5,
|
||||
top_p: Some(0.9),
|
||||
@@ -179,6 +360,7 @@ fn test_json_serialization() {
|
||||
tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::CodeInterpreter,
|
||||
..Default::default()
|
||||
}],
|
||||
top_logprobs: 10,
|
||||
top_p: Some(0.8),
|
||||
|
||||
Reference in New Issue
Block a user