[router] basic mcp support for openai router response api (#10978)

This commit is contained in:
Keyang Ru
2025-09-27 18:49:33 -07:00
committed by GitHub
parent c1c8dd1dd0
commit 72392f2908
4 changed files with 730 additions and 28 deletions

View File

@@ -683,6 +683,33 @@ pub struct CompletionStreamChoice {
/// A single tool entry of a Responses API request.
///
/// Carries the tool kind plus optional MCP settings. Every `Option` field is
/// left out of the serialized form when it is `None`.
pub struct ResponseTool {
/// Tool kind; serialized under the JSON key `type`.
#[serde(rename = "type")]
pub r#type: ResponseToolType,
// MCP-specific fields (used when type == "mcp")
/// URL of the MCP server. The router only accepts values starting with
/// `http://` or `https://` when building a request-scoped MCP manager.
#[serde(skip_serializing_if = "Option::is_none")]
pub server_url: Option<String>,
/// Credential handed to the MCP transport as its token.
/// NOTE(review): the expected format (e.g. bearer token) is not visible here — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub authorization: Option<String>,
/// Name given to the MCP server; the router falls back to `request-mcp`
/// when this is absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub server_label: Option<String>,
/// Free-form description of the MCP server; the router echoes it back in
/// the response `tools` entry.
#[serde(skip_serializing_if = "Option::is_none")]
pub server_description: Option<String>,
/// Approval setting supplied by the client; the router echoes it back.
/// NOTE(review): no enforcement is visible in the router code — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub require_approval: Option<String>,
/// Tool names supplied by the client; the router echoes them back.
/// NOTE(review): no filtering by this list is visible in the router code — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub allowed_tools: Option<Vec<String>>,
}
impl Default for ResponseTool {
fn default() -> Self {
Self {
r#type: ResponseToolType::WebSearchPreview,
server_url: None,
authorization: None,
server_label: None,
server_description: None,
require_approval: None,
allowed_tools: None,
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -690,6 +717,7 @@ pub struct ResponseTool {
/// Kinds of tool a Responses API request can declare.
pub enum ResponseToolType {
/// Tool kind used by `ResponseTool::default()`.
WebSearchPreview,
CodeInterpreter,
/// Tool backed by an MCP server; its settings live in the MCP-specific
/// fields of `ResponseTool`.
Mcp,
}
#[derive(Debug, Clone, Deserialize, Serialize)]

View File

@@ -6,8 +6,8 @@ use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse};
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
ResponseStatus, ResponseTextFormat, ResponsesGetParams, ResponsesRequest, ResponsesResponse,
TextFormatType,
ResponseStatus, ResponseTextFormat, ResponseTool, ResponseToolType, ResponsesGetParams,
ResponsesRequest, ResponsesResponse, TextFormatType,
};
use crate::routers::header_utils::{apply_request_headers, preserve_response_headers};
use async_trait::async_trait;
@@ -20,7 +20,14 @@ use axum::{
use bytes::Bytes;
use futures_util::StreamExt;
use serde_json::{json, to_value, Value};
use std::{any::Any, borrow::Cow, collections::HashMap, io, sync::atomic::AtomicBool};
use std::{
any::Any,
borrow::Cow,
collections::HashMap,
io,
sync::{atomic::AtomicBool, Arc},
time::SystemTime,
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{error, info, warn};
@@ -37,6 +44,8 @@ pub struct OpenAIRouter {
healthy: AtomicBool,
/// Response storage for managing conversation history
response_storage: SharedResponseStorage,
/// Optional MCP manager (enabled via config presence)
mcp_manager: Option<Arc<crate::mcp::McpClientManager>>,
}
impl std::fmt::Debug for OpenAIRouter {
@@ -210,12 +219,33 @@ impl OpenAIRouter {
let circuit_breaker = CircuitBreaker::with_config(core_cb_config);
// Optional MCP manager activation via env var path (config-driven gate)
let mcp_manager = match std::env::var("SGLANG_MCP_CONFIG").ok() {
Some(path) if !path.trim().is_empty() => {
match crate::mcp::McpConfig::from_file(&path).await {
Ok(cfg) => match crate::mcp::McpClientManager::new(cfg).await {
Ok(mgr) => Some(Arc::new(mgr)),
Err(err) => {
warn!("Failed to initialize MCP manager: {}", err);
None
}
},
Err(err) => {
warn!("Failed to load MCP config from '{}': {}", path, err);
None
}
}
}
_ => None,
};
Ok(Self {
client,
base_url,
circuit_breaker,
healthy: AtomicBool::new(true),
response_storage,
mcp_manager,
})
}
@@ -223,10 +253,79 @@ impl OpenAIRouter {
&self,
url: String,
headers: Option<&HeaderMap>,
payload: Value,
mut payload: Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
) -> Response {
// Request-scoped MCP: build from request tools if provided; otherwise fall back to router-level MCP
let req_mcp_manager = Self::mcp_manager_from_request_tools(&original_body.tools).await;
let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());
// If the client requested MCP but we couldn't initialize it, fail early with a clear error
let requested_mcp = original_body
.tools
.iter()
.any(|t| matches!(t.r#type, ResponseToolType::Mcp));
if requested_mcp && active_mcp.is_none() {
return (
StatusCode::BAD_GATEWAY,
json!({
"error": {
"message": "MCP server unavailable or failed to initialize from request tools",
"type": "mcp_unavailable",
"param": "tools",
}
})
.to_string(),
)
.into_response();
}
// If MCP is active, mirror one function tool into the outgoing payload
if let Some(mcp) = active_mcp {
if let Some(obj) = payload.as_object_mut() {
// Remove any non-function tools (e.g., custom "mcp" items) from outgoing payload
if let Some(v) = obj.get_mut("tools") {
if let Some(arr) = v.as_array_mut() {
arr.retain(|item| {
item.get("type")
.and_then(|v| v.as_str())
.map(|s| s == "function")
.unwrap_or(false)
});
if arr.is_empty() {
obj.remove("tools");
obj.insert(
"tool_choice".to_string(),
Value::String("none".to_string()),
);
}
}
}
// Build function tools for all discovered MCP tools
let mut tools_json = Vec::new();
let tools = mcp.list_tools();
for t in tools {
let parameters = t.parameters.clone().unwrap_or(serde_json::json!({
"type": "object",
"properties": {},
"additionalProperties": false
}));
let tool = serde_json::json!({
"type": "function",
"name": t.name,
"description": t.description,
"parameters": parameters
});
tools_json.push(tool);
}
if !tools_json.is_empty() {
obj.insert("tools".to_string(), Value::Array(tools_json));
// Ensure tool_choice auto to allow model planning
obj.insert("tool_choice".to_string(), Value::String("auto".to_string()));
}
}
}
let request_builder = self.client.post(&url).json(&payload);
// Apply headers with filtering
@@ -239,7 +338,6 @@ impl OpenAIRouter {
match request_builder.send().await {
Ok(response) => {
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
@@ -290,16 +388,124 @@ impl OpenAIRouter {
obj.insert("store".to_string(), Value::Bool(original_body.store));
}
let mut final_response_json = openai_response_json;
if let Some(mcp) = active_mcp {
if let Some((call_id, tool_name, args_json_str)) =
Self::extract_function_call(&final_response_json)
{
info!(
"Detected function call: name={}, call_id={}, args={}",
tool_name, call_id, args_json_str
);
let call_started = SystemTime::now();
let call_result =
Self::execute_mcp_call(mcp, &tool_name, &args_json_str).await;
let call_duration_ms =
call_started.elapsed().unwrap_or_default().as_millis();
let (output_payload, call_ok, call_error) = match call_result {
Ok((server, out)) => {
info!(
call_id = %call_id,
tool_name = %tool_name,
server = %server,
duration_ms = call_duration_ms,
"MCP tool call succeeded"
);
(out, true, None)
}
Err(err) => {
warn!(
call_id = %call_id,
tool_name = %tool_name,
duration_ms = call_duration_ms,
error = %err,
"MCP tool call failed"
);
(
serde_json::json!({
"error": err
})
.to_string(),
false,
Some(err),
)
}
};
match self
.resume_with_tool_result(ResumeWithToolArgs {
url: &url,
headers,
original_payload: &payload,
call_id: &call_id,
tool_name: &tool_name,
args_json_str: &args_json_str,
output_str: &output_payload,
original_body,
})
.await
{
Ok(mut resumed_json) => {
if !call_ok {
if let Some(obj) = resumed_json.as_object_mut() {
let metadata_value =
obj.entry("metadata").or_insert_with(|| {
Value::Object(serde_json::Map::new())
});
if let Some(metadata) =
metadata_value.as_object_mut()
{
if let Some(err_msg) = call_error.as_ref() {
metadata.insert(
"mcp_error".to_string(),
Value::String(err_msg.clone()),
);
}
}
}
}
final_response_json = resumed_json;
}
Err(err) => {
warn!("Failed to resume with tool result: {}", err);
let error_body = json!({
"error": {
"message": format!(
"Failed to resume with tool result: {}",
err
),
"type": "internal_error",
}
})
.to_string();
return (
StatusCode::INTERNAL_SERVER_ERROR,
[("content-type", "application/json")],
error_body,
)
.into_response();
}
}
} else {
info!("No function call found in upstream response; skipping MCP");
}
}
Self::mask_tools_as_mcp(&mut final_response_json, original_body);
if original_body.store {
if let Err(e) = self
.store_response_internal(&openai_response_json, original_body)
.store_response_internal(&final_response_json, original_body)
.await
{
warn!("Failed to store response: {}", e);
}
}
match serde_json::to_string(&openai_response_json) {
match serde_json::to_string(&final_response_json) {
Ok(json_str) => (
StatusCode::OK,
[("content-type", "application/json")],
@@ -334,6 +540,49 @@ impl OpenAIRouter {
}
}
/// Create an MCP client manager scoped to a single request.
///
/// The first tool of type `Mcp` that carries a `server_url` decides the
/// connection; any further MCP tools are ignored. The URL is trimmed and must
/// start with `http://` or `https://`. The server name comes from
/// `server_label` (default `request-mcp`) and `authorization` is passed on as
/// the transport token. A URL containing `/sse` selects the SSE transport,
/// anything else the streamable transport.
///
/// Returns `None` when no such tool exists, when the URL scheme is not
/// accepted, or when the manager fails to initialize (the latter two are
/// logged as warnings).
async fn mcp_manager_from_request_tools(
    tools: &[ResponseTool],
) -> Option<Arc<crate::mcp::McpClientManager>> {
    let (tool, raw_url) = tools.iter().find_map(|candidate| {
        if !matches!(candidate.r#type, ResponseToolType::Mcp) {
            return None;
        }
        candidate.server_url.as_ref().map(|url| (candidate, url))
    })?;

    let server_url = raw_url.trim().to_string();
    let has_http_scheme = ["http://", "https://"]
        .iter()
        .any(|scheme| server_url.starts_with(*scheme));
    if !has_http_scheme {
        warn!(
            "Ignoring MCP server_url with unsupported scheme: {}",
            server_url
        );
        return None;
    }

    let token = tool.authorization.clone();
    let name = match &tool.server_label {
        Some(label) => label.clone(),
        None => "request-mcp".to_string(),
    };

    // Transport is chosen purely from the URL text.
    let uses_sse = server_url.contains("/sse");
    let transport = if uses_sse {
        crate::mcp::McpTransport::Sse {
            url: server_url,
            token,
        }
    } else {
        crate::mcp::McpTransport::Streamable {
            url: server_url,
            token,
        }
    };

    let cfg = crate::mcp::McpConfig {
        servers: vec![crate::mcp::McpServerConfig { name, transport }],
    };
    let manager = crate::mcp::McpClientManager::new(cfg).await;
    if let Err(err) = &manager {
        warn!("Failed to initialize request-scoped MCP manager: {}", err);
    }
    manager.ok().map(Arc::new)
}
async fn handle_streaming_response(
&self,
url: String,
@@ -765,6 +1014,180 @@ impl OpenAIRouter {
}
}
/// Borrowed inputs for `resume_with_tool_result`, which re-sends a request
/// upstream together with the result of one executed tool call.
struct ResumeWithToolArgs<'a> {
/// Upstream endpoint the follow-up request is POSTed to.
url: &'a str,
/// Incoming request headers to forward, if any.
headers: Option<&'a HeaderMap>,
/// Payload of the first upstream request; cloned and modified for the follow-up.
original_payload: &'a Value,
/// Identifier linking the function call to its output item.
call_id: &'a str,
/// Name of the tool that was called.
tool_name: &'a str,
/// Arguments of the call, as the JSON string returned by the model.
args_json_str: &'a str,
/// Tool output (or error payload) as a string.
output_str: &'a str,
/// The client's original request; supplies `input`, `instructions`, `store` and `tools`.
original_body: &'a ResponsesRequest,
}
impl OpenAIRouter {
fn extract_function_call(resp: &Value) -> Option<(String, String, String)> {
let output = resp.get("output")?.as_array()?;
for item in output {
let obj = item.as_object()?;
let t = obj.get("type")?.as_str()?;
if t == "function_tool_call" || t == "function_call" {
let call_id = obj
.get("call_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| {
obj.get("id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
})?;
let name = obj.get("name")?.as_str()?.to_string();
let arguments = obj.get("arguments")?.as_str()?.to_string();
return Some((call_id, name, arguments));
}
}
None
}
/// Rewrite the `tools` entry of a response so it mirrors the MCP tool the
/// client sent, hiding the function tools used internally.
///
/// Does nothing unless the original request contains a tool of type `Mcp`
/// with a `server_url`. The rebuilt entry carries `type: "mcp"` plus whichever
/// of `server_label`, `server_url`, `server_description`, `require_approval`
/// and `allowed_tools` were set; `authorization` is not copied. `tool_choice`
/// is set to `"auto"` only when the response has none.
fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
    let Some(requested) = original_body
        .tools
        .iter()
        .find(|tool| tool.server_url.is_some() && matches!(tool.r#type, ResponseToolType::Mcp))
    else {
        return;
    };

    let mut entry = serde_json::Map::new();
    entry.insert("type".to_string(), Value::String("mcp".to_string()));

    // Optional string fields, listed in the order they are echoed back.
    let string_fields = [
        ("server_label", &requested.server_label),
        ("server_url", &requested.server_url),
        ("server_description", &requested.server_description),
        ("require_approval", &requested.require_approval),
    ];
    for (key, value) in string_fields {
        if let Some(text) = value {
            entry.insert(key.to_string(), Value::String(text.clone()));
        }
    }

    if let Some(names) = &requested.allowed_tools {
        let listed: Vec<Value> = names.iter().cloned().map(Value::String).collect();
        entry.insert("allowed_tools".to_string(), Value::Array(listed));
    }

    let Some(body) = resp.as_object_mut() else {
        return;
    };
    body.insert(
        "tools".to_string(),
        Value::Array(vec![Value::Object(entry)]),
    );
    body.entry("tool_choice")
        .or_insert_with(|| Value::String("auto".to_string()));
}
async fn execute_mcp_call(
mcp_mgr: &Arc<crate::mcp::McpClientManager>,
tool_name: &str,
args_json_str: &str,
) -> Result<(String, String), String> {
let args_value: Value =
serde_json::from_str(args_json_str).map_err(|e| format!("parse tool args: {}", e))?;
let args_obj = args_value.as_object().cloned();
let server_name = mcp_mgr
.get_tool(tool_name)
.map(|t| t.server)
.ok_or_else(|| format!("tool not found: {}", tool_name))?;
let result = mcp_mgr
.call_tool(tool_name, args_obj)
.await
.map_err(|e| format!("tool call failed: {}", e))?;
let output_str = serde_json::to_string(&result)
.map_err(|e| format!("Failed to serialize tool result: {}", e))?;
Ok((server_name, output_str))
}
async fn resume_with_tool_result(&self, args: ResumeWithToolArgs<'_>) -> Result<Value, String> {
let mut payload2 = args.original_payload.clone();
let obj = payload2
.as_object_mut()
.ok_or_else(|| "payload not an object".to_string())?;
// Build function_call and tool result items per OpenAI Responses spec
let user_item = serde_json::json!({
"type": "message",
"role": "user",
"content": args.original_body.input.clone()
});
// temp system message since currently only support 1 turn of mcp function call
let system_item = serde_json::json!({
"type": "message",
"role": "system",
"content": "please resume with the following tool result, and answer user's question directly, don't trigger any more tool calls"
});
let func_item = serde_json::json!({
"type": "function_call",
"call_id": args.call_id,
"name": args.tool_name,
"arguments": args.args_json_str
});
// Build tool result item as function_call_output per OpenAI Responses spec
let tool_item = serde_json::json!({
"type": "function_call_output",
"call_id": args.call_id,
"output": args.output_str
});
obj.insert(
"input".to_string(),
Value::Array(vec![user_item, system_item, func_item, tool_item]),
);
// Ensure non-streaming and no store to upstream
obj.insert("stream".to_string(), Value::Bool(false));
obj.insert("store".to_string(), Value::Bool(false));
let mut req = self.client.post(args.url).json(&payload2);
if let Some(headers) = args.headers {
req = apply_request_headers(headers, req, true);
}
let resp = req
.send()
.await
.map_err(|e| format!("resume request failed: {}", e))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(format!("resume upstream error {}: {}", status, body));
}
let mut v = resp
.json::<Value>()
.await
.map_err(|e| format!("parse resume response: {}", e))?;
if let Some(instr) = &args.original_body.instructions {
if let Some(obj) = v.as_object_mut() {
obj.entry("instructions")
.or_insert(Value::String(instr.clone()));
}
}
// After resume, mask tools as MCP if request used MCP
Self::mask_tools_as_mcp(&mut v, args.original_body);
if let Some(obj) = v.as_object_mut() {
obj.insert("store".to_string(), Value::Bool(args.original_body.store));
}
Ok(v)
}
}
#[async_trait]
impl super::super::RouterTrait for OpenAIRouter {
fn as_any(&self) -> &dyn Any {