[router] basic mcp support for openai router response api (#10978)
This commit is contained in:
@@ -683,6 +683,33 @@ pub struct CompletionStreamChoice {
|
||||
pub struct ResponseTool {
|
||||
#[serde(rename = "type")]
|
||||
pub r#type: ResponseToolType,
|
||||
// MCP-specific fields (used when type == "mcp")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub server_url: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub authorization: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub server_label: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub server_description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub require_approval: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl Default for ResponseTool {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
r#type: ResponseToolType::WebSearchPreview,
|
||||
server_url: None,
|
||||
authorization: None,
|
||||
server_label: None,
|
||||
server_description: None,
|
||||
require_approval: None,
|
||||
allowed_tools: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@@ -690,6 +717,7 @@ pub struct ResponseTool {
|
||||
/// Kinds of tool a `ResponseTool` entry can describe.
pub enum ResponseToolType {
    /// Web search preview tool.
    WebSearchPreview,
    /// Code interpreter tool.
    CodeInterpreter,
    /// Tool backed by an MCP server; the MCP-specific fields of
    /// `ResponseTool` apply to this variant.
    Mcp,
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
|
||||
@@ -6,8 +6,8 @@ use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
|
||||
ResponseStatus, ResponseTextFormat, ResponsesGetParams, ResponsesRequest, ResponsesResponse,
|
||||
TextFormatType,
|
||||
ResponseStatus, ResponseTextFormat, ResponseTool, ResponseToolType, ResponsesGetParams,
|
||||
ResponsesRequest, ResponsesResponse, TextFormatType,
|
||||
};
|
||||
use crate::routers::header_utils::{apply_request_headers, preserve_response_headers};
|
||||
use async_trait::async_trait;
|
||||
@@ -20,7 +20,14 @@ use axum::{
|
||||
use bytes::Bytes;
|
||||
use futures_util::StreamExt;
|
||||
use serde_json::{json, to_value, Value};
|
||||
use std::{any::Any, borrow::Cow, collections::HashMap, io, sync::atomic::AtomicBool};
|
||||
use std::{
|
||||
any::Any,
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
io,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
time::SystemTime,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{error, info, warn};
|
||||
@@ -37,6 +44,8 @@ pub struct OpenAIRouter {
|
||||
healthy: AtomicBool,
|
||||
/// Response storage for managing conversation history
|
||||
response_storage: SharedResponseStorage,
|
||||
/// Optional MCP manager (enabled via config presence)
|
||||
mcp_manager: Option<Arc<crate::mcp::McpClientManager>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OpenAIRouter {
|
||||
@@ -210,12 +219,33 @@ impl OpenAIRouter {
|
||||
|
||||
let circuit_breaker = CircuitBreaker::with_config(core_cb_config);
|
||||
|
||||
// Optional MCP manager activation via env var path (config-driven gate)
|
||||
let mcp_manager = match std::env::var("SGLANG_MCP_CONFIG").ok() {
|
||||
Some(path) if !path.trim().is_empty() => {
|
||||
match crate::mcp::McpConfig::from_file(&path).await {
|
||||
Ok(cfg) => match crate::mcp::McpClientManager::new(cfg).await {
|
||||
Ok(mgr) => Some(Arc::new(mgr)),
|
||||
Err(err) => {
|
||||
warn!("Failed to initialize MCP manager: {}", err);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
warn!("Failed to load MCP config from '{}': {}", path, err);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
base_url,
|
||||
circuit_breaker,
|
||||
healthy: AtomicBool::new(true),
|
||||
response_storage,
|
||||
mcp_manager,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -223,10 +253,79 @@ impl OpenAIRouter {
|
||||
&self,
|
||||
url: String,
|
||||
headers: Option<&HeaderMap>,
|
||||
payload: Value,
|
||||
mut payload: Value,
|
||||
original_body: &ResponsesRequest,
|
||||
original_previous_response_id: Option<String>,
|
||||
) -> Response {
|
||||
// Request-scoped MCP: build from request tools if provided; otherwise fall back to router-level MCP
|
||||
let req_mcp_manager = Self::mcp_manager_from_request_tools(&original_body.tools).await;
|
||||
let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());
|
||||
|
||||
// If the client requested MCP but we couldn't initialize it, fail early with a clear error
|
||||
let requested_mcp = original_body
|
||||
.tools
|
||||
.iter()
|
||||
.any(|t| matches!(t.r#type, ResponseToolType::Mcp));
|
||||
if requested_mcp && active_mcp.is_none() {
|
||||
return (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
json!({
|
||||
"error": {
|
||||
"message": "MCP server unavailable or failed to initialize from request tools",
|
||||
"type": "mcp_unavailable",
|
||||
"param": "tools",
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
// If MCP is active, mirror one function tool into the outgoing payload
|
||||
if let Some(mcp) = active_mcp {
|
||||
if let Some(obj) = payload.as_object_mut() {
|
||||
// Remove any non-function tools (e.g., custom "mcp" items) from outgoing payload
|
||||
if let Some(v) = obj.get_mut("tools") {
|
||||
if let Some(arr) = v.as_array_mut() {
|
||||
arr.retain(|item| {
|
||||
item.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s == "function")
|
||||
.unwrap_or(false)
|
||||
});
|
||||
if arr.is_empty() {
|
||||
obj.remove("tools");
|
||||
obj.insert(
|
||||
"tool_choice".to_string(),
|
||||
Value::String("none".to_string()),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Build function tools for all discovered MCP tools
|
||||
let mut tools_json = Vec::new();
|
||||
let tools = mcp.list_tools();
|
||||
for t in tools {
|
||||
let parameters = t.parameters.clone().unwrap_or(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"additionalProperties": false
|
||||
}));
|
||||
let tool = serde_json::json!({
|
||||
"type": "function",
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"parameters": parameters
|
||||
});
|
||||
tools_json.push(tool);
|
||||
}
|
||||
if !tools_json.is_empty() {
|
||||
obj.insert("tools".to_string(), Value::Array(tools_json));
|
||||
// Ensure tool_choice auto to allow model planning
|
||||
obj.insert("tool_choice".to_string(), Value::String("auto".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
let request_builder = self.client.post(&url).json(&payload);
|
||||
|
||||
// Apply headers with filtering
|
||||
@@ -239,7 +338,6 @@ impl OpenAIRouter {
|
||||
match request_builder.send().await {
|
||||
Ok(response) => {
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let error_text = response
|
||||
.text()
|
||||
@@ -290,16 +388,124 @@ impl OpenAIRouter {
|
||||
obj.insert("store".to_string(), Value::Bool(original_body.store));
|
||||
}
|
||||
|
||||
let mut final_response_json = openai_response_json;
|
||||
|
||||
if let Some(mcp) = active_mcp {
|
||||
if let Some((call_id, tool_name, args_json_str)) =
|
||||
Self::extract_function_call(&final_response_json)
|
||||
{
|
||||
info!(
|
||||
"Detected function call: name={}, call_id={}, args={}",
|
||||
tool_name, call_id, args_json_str
|
||||
);
|
||||
|
||||
let call_started = SystemTime::now();
|
||||
let call_result =
|
||||
Self::execute_mcp_call(mcp, &tool_name, &args_json_str).await;
|
||||
let call_duration_ms =
|
||||
call_started.elapsed().unwrap_or_default().as_millis();
|
||||
|
||||
let (output_payload, call_ok, call_error) = match call_result {
|
||||
Ok((server, out)) => {
|
||||
info!(
|
||||
call_id = %call_id,
|
||||
tool_name = %tool_name,
|
||||
server = %server,
|
||||
duration_ms = call_duration_ms,
|
||||
"MCP tool call succeeded"
|
||||
);
|
||||
(out, true, None)
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
call_id = %call_id,
|
||||
tool_name = %tool_name,
|
||||
duration_ms = call_duration_ms,
|
||||
error = %err,
|
||||
"MCP tool call failed"
|
||||
);
|
||||
(
|
||||
serde_json::json!({
|
||||
"error": err
|
||||
})
|
||||
.to_string(),
|
||||
false,
|
||||
Some(err),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
match self
|
||||
.resume_with_tool_result(ResumeWithToolArgs {
|
||||
url: &url,
|
||||
headers,
|
||||
original_payload: &payload,
|
||||
call_id: &call_id,
|
||||
tool_name: &tool_name,
|
||||
args_json_str: &args_json_str,
|
||||
output_str: &output_payload,
|
||||
original_body,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(mut resumed_json) => {
|
||||
if !call_ok {
|
||||
if let Some(obj) = resumed_json.as_object_mut() {
|
||||
let metadata_value =
|
||||
obj.entry("metadata").or_insert_with(|| {
|
||||
Value::Object(serde_json::Map::new())
|
||||
});
|
||||
if let Some(metadata) =
|
||||
metadata_value.as_object_mut()
|
||||
{
|
||||
if let Some(err_msg) = call_error.as_ref() {
|
||||
metadata.insert(
|
||||
"mcp_error".to_string(),
|
||||
Value::String(err_msg.clone()),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
final_response_json = resumed_json;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to resume with tool result: {}", err);
|
||||
let error_body = json!({
|
||||
"error": {
|
||||
"message": format!(
|
||||
"Failed to resume with tool result: {}",
|
||||
err
|
||||
),
|
||||
"type": "internal_error",
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
[("content-type", "application/json")],
|
||||
error_body,
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!("No function call found in upstream response; skipping MCP");
|
||||
}
|
||||
}
|
||||
|
||||
Self::mask_tools_as_mcp(&mut final_response_json, original_body);
|
||||
if original_body.store {
|
||||
if let Err(e) = self
|
||||
.store_response_internal(&openai_response_json, original_body)
|
||||
.store_response_internal(&final_response_json, original_body)
|
||||
.await
|
||||
{
|
||||
warn!("Failed to store response: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
match serde_json::to_string(&openai_response_json) {
|
||||
match serde_json::to_string(&final_response_json) {
|
||||
Ok(json_str) => (
|
||||
StatusCode::OK,
|
||||
[("content-type", "application/json")],
|
||||
@@ -334,6 +540,49 @@ impl OpenAIRouter {
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a request-scoped MCP manager from request tools, if present.
|
||||
async fn mcp_manager_from_request_tools(
|
||||
tools: &[ResponseTool],
|
||||
) -> Option<Arc<crate::mcp::McpClientManager>> {
|
||||
let tool = tools
|
||||
.iter()
|
||||
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?;
|
||||
let server_url = tool.server_url.as_ref()?.trim().to_string();
|
||||
if !(server_url.starts_with("http://") || server_url.starts_with("https://")) {
|
||||
warn!(
|
||||
"Ignoring MCP server_url with unsupported scheme: {}",
|
||||
server_url
|
||||
);
|
||||
return None;
|
||||
}
|
||||
let name = tool
|
||||
.server_label
|
||||
.clone()
|
||||
.unwrap_or_else(|| "request-mcp".to_string());
|
||||
let token = tool.authorization.clone();
|
||||
let transport = if server_url.contains("/sse") {
|
||||
crate::mcp::McpTransport::Sse {
|
||||
url: server_url,
|
||||
token,
|
||||
}
|
||||
} else {
|
||||
crate::mcp::McpTransport::Streamable {
|
||||
url: server_url,
|
||||
token,
|
||||
}
|
||||
};
|
||||
let cfg = crate::mcp::McpConfig {
|
||||
servers: vec![crate::mcp::McpServerConfig { name, transport }],
|
||||
};
|
||||
match crate::mcp::McpClientManager::new(cfg).await {
|
||||
Ok(mgr) => Some(Arc::new(mgr)),
|
||||
Err(err) => {
|
||||
warn!("Failed to initialize request-scoped MCP manager: {}", err);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_streaming_response(
|
||||
&self,
|
||||
url: String,
|
||||
@@ -765,6 +1014,180 @@ impl OpenAIRouter {
|
||||
}
|
||||
}
|
||||
|
||||
struct ResumeWithToolArgs<'a> {
|
||||
url: &'a str,
|
||||
headers: Option<&'a HeaderMap>,
|
||||
original_payload: &'a Value,
|
||||
call_id: &'a str,
|
||||
tool_name: &'a str,
|
||||
args_json_str: &'a str,
|
||||
output_str: &'a str,
|
||||
original_body: &'a ResponsesRequest,
|
||||
}
|
||||
|
||||
impl OpenAIRouter {
|
||||
fn extract_function_call(resp: &Value) -> Option<(String, String, String)> {
|
||||
let output = resp.get("output")?.as_array()?;
|
||||
for item in output {
|
||||
let obj = item.as_object()?;
|
||||
let t = obj.get("type")?.as_str()?;
|
||||
if t == "function_tool_call" || t == "function_call" {
|
||||
let call_id = obj
|
||||
.get("call_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| {
|
||||
obj.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
})?;
|
||||
let name = obj.get("name")?.as_str()?.to_string();
|
||||
let arguments = obj.get("arguments")?.as_str()?.to_string();
|
||||
return Some((call_id, name, arguments));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Replace returned tools with the original request's MCP tool block (if present) so
|
||||
/// external clients see MCP semantics rather than internal function tools.
|
||||
fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
|
||||
let mcp_tool = original_body
|
||||
.tools
|
||||
.iter()
|
||||
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some());
|
||||
let Some(t) = mcp_tool else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut m = serde_json::Map::new();
|
||||
m.insert("type".to_string(), Value::String("mcp".to_string()));
|
||||
if let Some(label) = &t.server_label {
|
||||
m.insert("server_label".to_string(), Value::String(label.clone()));
|
||||
}
|
||||
if let Some(url) = &t.server_url {
|
||||
m.insert("server_url".to_string(), Value::String(url.clone()));
|
||||
}
|
||||
if let Some(desc) = &t.server_description {
|
||||
m.insert(
|
||||
"server_description".to_string(),
|
||||
Value::String(desc.clone()),
|
||||
);
|
||||
}
|
||||
if let Some(req) = &t.require_approval {
|
||||
m.insert("require_approval".to_string(), Value::String(req.clone()));
|
||||
}
|
||||
if let Some(allowed) = &t.allowed_tools {
|
||||
m.insert(
|
||||
"allowed_tools".to_string(),
|
||||
Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()),
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(obj) = resp.as_object_mut() {
|
||||
obj.insert("tools".to_string(), Value::Array(vec![Value::Object(m)]));
|
||||
obj.entry("tool_choice")
|
||||
.or_insert(Value::String("auto".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_mcp_call(
|
||||
mcp_mgr: &Arc<crate::mcp::McpClientManager>,
|
||||
tool_name: &str,
|
||||
args_json_str: &str,
|
||||
) -> Result<(String, String), String> {
|
||||
let args_value: Value =
|
||||
serde_json::from_str(args_json_str).map_err(|e| format!("parse tool args: {}", e))?;
|
||||
let args_obj = args_value.as_object().cloned();
|
||||
|
||||
let server_name = mcp_mgr
|
||||
.get_tool(tool_name)
|
||||
.map(|t| t.server)
|
||||
.ok_or_else(|| format!("tool not found: {}", tool_name))?;
|
||||
|
||||
let result = mcp_mgr
|
||||
.call_tool(tool_name, args_obj)
|
||||
.await
|
||||
.map_err(|e| format!("tool call failed: {}", e))?;
|
||||
|
||||
let output_str = serde_json::to_string(&result)
|
||||
.map_err(|e| format!("Failed to serialize tool result: {}", e))?;
|
||||
Ok((server_name, output_str))
|
||||
}
|
||||
|
||||
async fn resume_with_tool_result(&self, args: ResumeWithToolArgs<'_>) -> Result<Value, String> {
|
||||
let mut payload2 = args.original_payload.clone();
|
||||
let obj = payload2
|
||||
.as_object_mut()
|
||||
.ok_or_else(|| "payload not an object".to_string())?;
|
||||
|
||||
// Build function_call and tool result items per OpenAI Responses spec
|
||||
let user_item = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": args.original_body.input.clone()
|
||||
});
|
||||
// temp system message since currently only support 1 turn of mcp function call
|
||||
let system_item = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "system",
|
||||
"content": "please resume with the following tool result, and answer user's question directly, don't trigger any more tool calls"
|
||||
});
|
||||
|
||||
let func_item = serde_json::json!({
|
||||
"type": "function_call",
|
||||
"call_id": args.call_id,
|
||||
"name": args.tool_name,
|
||||
"arguments": args.args_json_str
|
||||
});
|
||||
// Build tool result item as function_call_output per OpenAI Responses spec
|
||||
let tool_item = serde_json::json!({
|
||||
"type": "function_call_output",
|
||||
"call_id": args.call_id,
|
||||
"output": args.output_str
|
||||
});
|
||||
|
||||
obj.insert(
|
||||
"input".to_string(),
|
||||
Value::Array(vec![user_item, system_item, func_item, tool_item]),
|
||||
);
|
||||
|
||||
// Ensure non-streaming and no store to upstream
|
||||
obj.insert("stream".to_string(), Value::Bool(false));
|
||||
obj.insert("store".to_string(), Value::Bool(false));
|
||||
let mut req = self.client.post(args.url).json(&payload2);
|
||||
if let Some(headers) = args.headers {
|
||||
req = apply_request_headers(headers, req, true);
|
||||
}
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("resume request failed: {}", e))?;
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("resume upstream error {}: {}", status, body));
|
||||
}
|
||||
let mut v = resp
|
||||
.json::<Value>()
|
||||
.await
|
||||
.map_err(|e| format!("parse resume response: {}", e))?;
|
||||
|
||||
if let Some(instr) = &args.original_body.instructions {
|
||||
if let Some(obj) = v.as_object_mut() {
|
||||
obj.entry("instructions")
|
||||
.or_insert(Value::String(instr.clone()));
|
||||
}
|
||||
}
|
||||
// After resume, mask tools as MCP if request used MCP
|
||||
Self::mask_tools_as_mcp(&mut v, args.original_body);
|
||||
if let Some(obj) = v.as_object_mut() {
|
||||
obj.insert("store".to_string(), Value::Bool(args.original_body.store));
|
||||
}
|
||||
Ok(v)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::super::RouterTrait for OpenAIRouter {
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
|
||||
@@ -644,27 +644,96 @@ async fn responses_handler(
|
||||
}))
|
||||
.into_response()
|
||||
} else {
|
||||
Json(json!({
|
||||
"id": format!("resp-{}", Uuid::new_v4()),
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "This is a mock responses output."
|
||||
}]
|
||||
}],
|
||||
"status": "completed",
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15
|
||||
}
|
||||
}))
|
||||
.into_response()
|
||||
// If tools are provided and this is the first call (no previous_response_id),
|
||||
// emit a single function_tool_call to trigger the router's MCP flow.
|
||||
let has_tools = payload
|
||||
.get("tools")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter().any(|tool| {
|
||||
tool.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|t| t == "function")
|
||||
.unwrap_or(false)
|
||||
})
|
||||
})
|
||||
.unwrap_or(false);
|
||||
let has_function_output = payload
|
||||
.get("input")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|items| {
|
||||
items.iter().any(|item| {
|
||||
item.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|t| t == "function_call_output")
|
||||
.unwrap_or(false)
|
||||
})
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
if has_tools && !has_function_output {
|
||||
let rid = format!("resp-{}", Uuid::new_v4());
|
||||
Json(json!({
|
||||
"id": rid,
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [{
|
||||
"type": "function_tool_call",
|
||||
"id": "call_1",
|
||||
"name": "brave_web_search",
|
||||
"arguments": "{\"query\":\"SGLang router MCP integration\"}",
|
||||
"status": "in_progress"
|
||||
}],
|
||||
"status": "in_progress",
|
||||
"usage": null
|
||||
}))
|
||||
.into_response()
|
||||
} else if has_tools && has_function_output {
|
||||
Json(json!({
|
||||
"id": format!("resp-{}", Uuid::new_v4()),
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "Tool result consumed; here is the final answer."
|
||||
}]
|
||||
}],
|
||||
"status": "completed",
|
||||
"usage": {
|
||||
"input_tokens": 12,
|
||||
"output_tokens": 7,
|
||||
"total_tokens": 19
|
||||
}
|
||||
}))
|
||||
.into_response()
|
||||
} else {
|
||||
Json(json!({
|
||||
"id": format!("resp-{}", Uuid::new_v4()),
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "This is a mock responses output."
|
||||
}]
|
||||
}],
|
||||
"status": "completed",
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15
|
||||
}
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,186 @@ use sglang_router_rs::protocols::spec::{
|
||||
ToolChoiceValue, Truncation, UsageInfo,
|
||||
};
|
||||
|
||||
mod common;
|
||||
use common::mock_mcp_server::MockMCPServer;
|
||||
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConnectionMode, HealthCheckConfig, PolicyConfig, RetryConfig,
|
||||
RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::routers::RouterFactory;
|
||||
use sglang_router_rs::server::AppContext;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
||||
// Start mock MCP server
|
||||
let mut mcp = MockMCPServer::start().await.expect("start mcp");
|
||||
|
||||
// Write a temp MCP config file
|
||||
let mcp_yaml = format!(
|
||||
"servers:\n - name: mock\n protocol: streamable\n url: {}\n",
|
||||
mcp.url()
|
||||
);
|
||||
let dir = tempfile::tempdir().expect("tmpdir");
|
||||
let cfg_path = dir.path().join("mcp.yaml");
|
||||
std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
|
||||
|
||||
// Start mock OpenAI worker
|
||||
let mut worker = MockWorker::new(MockWorkerConfig {
|
||||
port: 0,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
});
|
||||
let worker_url = worker.start().await.expect("start worker");
|
||||
|
||||
// Build router config (HTTP OpenAI mode)
|
||||
let router_cfg = RouterConfig {
|
||||
mode: RoutingMode::OpenAI {
|
||||
worker_urls: vec![worker_url],
|
||||
},
|
||||
connection_mode: ConnectionMode::Http,
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 0,
|
||||
max_payload_size: 8 * 1024 * 1024,
|
||||
request_timeout_secs: 60,
|
||||
worker_startup_timeout_secs: 5,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: Some("warn".to_string()),
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 32,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 5,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
// Create router and context
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
|
||||
// Build a simple ResponsesRequest that will trigger the tool call
|
||||
let req = ResponsesRequest {
|
||||
background: false,
|
||||
include: None,
|
||||
input: ResponseInput::Text("search something".to_string()),
|
||||
instructions: Some("Be brief".to_string()),
|
||||
max_output_tokens: Some(64),
|
||||
max_tool_calls: None,
|
||||
metadata: None,
|
||||
model: Some("mock-model".to_string()),
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
service_tier: sglang_router_rs::protocols::spec::ServiceTier::Auto,
|
||||
store: true,
|
||||
stream: false,
|
||||
temperature: Some(0.2),
|
||||
tool_choice: sglang_router_rs::protocols::spec::ToolChoice::default(),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::Mcp,
|
||||
server_url: Some(mcp.url()),
|
||||
authorization: None,
|
||||
server_label: Some("mock".to_string()),
|
||||
server_description: None,
|
||||
require_approval: None,
|
||||
allowed_tools: None,
|
||||
}],
|
||||
top_logprobs: 0,
|
||||
top_p: None,
|
||||
truncation: sglang_router_rs::protocols::spec::Truncation::Disabled,
|
||||
user: None,
|
||||
request_id: "resp_test_mcp_e2e".to_string(),
|
||||
priority: 0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop: None,
|
||||
top_k: -1,
|
||||
min_p: 0.0,
|
||||
repetition_penalty: 1.0,
|
||||
};
|
||||
|
||||
let resp = router
|
||||
.route_responses(None, &req, req.model.as_deref())
|
||||
.await;
|
||||
|
||||
assert_eq!(resp.status(), axum::http::StatusCode::OK);
|
||||
|
||||
let body_bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Failed to read response body");
|
||||
let body_json: serde_json::Value =
|
||||
serde_json::from_slice(&body_bytes).expect("Failed to parse response JSON");
|
||||
|
||||
let output = body_json
|
||||
.get("output")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("response output missing");
|
||||
assert!(!output.is_empty(), "expected at least one output item");
|
||||
|
||||
let final_text = output
|
||||
.iter()
|
||||
.rev()
|
||||
.filter_map(|entry| entry.get("content"))
|
||||
.filter_map(|content| content.as_array())
|
||||
.flat_map(|parts| parts.iter())
|
||||
.filter_map(|part| part.get("text"))
|
||||
.filter_map(|v| v.as_str())
|
||||
.next();
|
||||
|
||||
if let Some(text) = final_text {
|
||||
assert_eq!(text, "Tool result consumed; here is the final answer.");
|
||||
} else {
|
||||
let call_entry = output.iter().find(|entry| {
|
||||
entry.get("type") == Some(&serde_json::Value::String("function_tool_call".into()))
|
||||
});
|
||||
assert!(call_entry.is_some(), "missing function tool call entry");
|
||||
if let Some(entry) = call_entry {
|
||||
assert_eq!(
|
||||
entry.get("status").and_then(|v| v.as_str()),
|
||||
Some("in_progress"),
|
||||
"function call should be in progress when no content is returned"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let tools = body_json
|
||||
.get("tools")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("tools array missing");
|
||||
assert_eq!(tools.len(), 1);
|
||||
let tool = tools.first().unwrap();
|
||||
assert_eq!(tool.get("type").and_then(|v| v.as_str()), Some("mcp"));
|
||||
assert_eq!(
|
||||
tool.get("server_label").and_then(|v| v.as_str()),
|
||||
Some("mock")
|
||||
);
|
||||
|
||||
// Cleanup
|
||||
worker.stop().await;
|
||||
mcp.stop().await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_responses_request_creation() {
|
||||
let request = ResponsesRequest {
|
||||
@@ -29,6 +209,7 @@ fn test_responses_request_creation() {
|
||||
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::WebSearchPreview,
|
||||
..Default::default()
|
||||
}],
|
||||
top_logprobs: 5,
|
||||
top_p: Some(0.9),
|
||||
@@ -179,6 +360,7 @@ fn test_json_serialization() {
|
||||
tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::CodeInterpreter,
|
||||
..Default::default()
|
||||
}],
|
||||
top_logprobs: 10,
|
||||
top_p: Some(0.8),
|
||||
|
||||
Reference in New Issue
Block a user