[router] move to mcp sdk instead (#10057)
This commit is contained in:
@@ -55,6 +55,15 @@ tiktoken-rs = { version = "0.7.0" }
|
||||
minijinja = { version = "2.0" }
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
hf-hub = { version = "0.4.3", features = ["tokio"] }
|
||||
rmcp = { version = "0.6.3", features = ["client", "server",
|
||||
"transport-child-process",
|
||||
"transport-sse-client-reqwest",
|
||||
"transport-streamable-http-client-reqwest",
|
||||
"transport-streamable-http-server",
|
||||
"transport-streamable-http-server-session",
|
||||
"reqwest",
|
||||
"auth"] }
|
||||
serde_yaml = "0.9"
|
||||
|
||||
# gRPC and Protobuf dependencies
|
||||
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
|
||||
|
||||
535
sgl-router/src/mcp/client_manager.rs
Normal file
535
sgl-router/src/mcp/client_manager.rs
Normal file
@@ -0,0 +1,535 @@
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use dashmap::DashMap;
|
||||
use rmcp::{
|
||||
model::{
|
||||
CallToolRequestParam, GetPromptRequestParam, GetPromptResult, Prompt,
|
||||
ReadResourceRequestParam, ReadResourceResult, Resource, Tool as McpTool,
|
||||
},
|
||||
service::RunningService,
|
||||
transport::{
|
||||
sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig,
|
||||
ConfigureCommandExt, SseClientTransport, StreamableHttpClientTransport, TokioChildProcess,
|
||||
},
|
||||
RoleClient, ServiceExt,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{borrow::Cow, collections::HashMap, time::Duration};
|
||||
|
||||
use crate::mcp::{
|
||||
config::{McpConfig, McpServerConfig, McpTransport},
|
||||
error::{McpError, McpResult},
|
||||
};
|
||||
|
||||
/// Information about an available tool
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolInfo {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub server: String,
|
||||
pub parameters: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Information about an available prompt
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptInfo {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub server: String,
|
||||
pub arguments: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
/// Information about an available resource
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResourceInfo {
|
||||
pub uri: String,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub mime_type: Option<String>,
|
||||
pub server: String,
|
||||
}
|
||||
|
||||
/// Manages MCP client connections and tool execution
|
||||
pub struct McpClientManager {
|
||||
/// Map of server_name -> MCP client
|
||||
clients: HashMap<String, RunningService<RoleClient, ()>>,
|
||||
/// Map of tool_name -> (server_name, tool_definition)
|
||||
tools: DashMap<String, (String, McpTool)>,
|
||||
/// Map of prompt_name -> (server_name, prompt_definition)
|
||||
prompts: DashMap<String, (String, Prompt)>,
|
||||
/// Map of resource_uri -> (server_name, resource_definition)
|
||||
resources: DashMap<String, (String, Resource)>,
|
||||
}
|
||||
|
||||
impl McpClientManager {
|
||||
/// Create a new manager and connect to all configured servers
|
||||
pub async fn new(config: McpConfig) -> McpResult<Self> {
|
||||
let mut mgr = Self {
|
||||
clients: HashMap::new(),
|
||||
tools: DashMap::new(),
|
||||
prompts: DashMap::new(),
|
||||
resources: DashMap::new(),
|
||||
};
|
||||
|
||||
for server_config in config.servers {
|
||||
match Self::connect_server(&server_config).await {
|
||||
Ok(client) => {
|
||||
mgr.load_server_inventory(&server_config.name, &client)
|
||||
.await;
|
||||
mgr.clients.insert(server_config.name.clone(), client);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"Failed to connect to server '{}': {}",
|
||||
server_config.name,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mgr.clients.is_empty() {
|
||||
return Err(McpError::ConnectionFailed(
|
||||
"Failed to connect to any MCP servers".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(mgr)
|
||||
}
|
||||
|
||||
/// Discover and cache tools/prompts/resources for a connected server
|
||||
async fn load_server_inventory(
|
||||
&self,
|
||||
server_name: &str,
|
||||
client: &RunningService<RoleClient, ()>,
|
||||
) {
|
||||
// Tools
|
||||
match client.peer().list_all_tools().await {
|
||||
Ok(ts) => {
|
||||
tracing::info!("Discovered {} tools from '{}'", ts.len(), server_name);
|
||||
for t in ts {
|
||||
if self.tools.contains_key(t.name.as_ref()) {
|
||||
tracing::warn!(
|
||||
"Tool '{}' from server '{}' is overwriting an existing tool.",
|
||||
&t.name,
|
||||
server_name
|
||||
);
|
||||
}
|
||||
self.tools
|
||||
.insert(t.name.to_string(), (server_name.to_string(), t));
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::warn!("Failed to list tools from '{}': {}", server_name, e),
|
||||
}
|
||||
|
||||
// Prompts
|
||||
match client.peer().list_all_prompts().await {
|
||||
Ok(ps) => {
|
||||
tracing::info!("Discovered {} prompts from '{}'", ps.len(), server_name);
|
||||
for p in ps {
|
||||
if self.prompts.contains_key(&p.name) {
|
||||
tracing::warn!(
|
||||
"Prompt '{}' from server '{}' is overwriting an existing prompt.",
|
||||
&p.name,
|
||||
server_name
|
||||
);
|
||||
}
|
||||
self.prompts
|
||||
.insert(p.name.clone(), (server_name.to_string(), p));
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::debug!("No prompts or failed to list on '{}': {}", server_name, e),
|
||||
}
|
||||
|
||||
// Resources
|
||||
match client.peer().list_all_resources().await {
|
||||
Ok(rs) => {
|
||||
tracing::info!("Discovered {} resources from '{}'", rs.len(), server_name);
|
||||
for r in rs {
|
||||
if self.resources.contains_key(&r.uri) {
|
||||
tracing::warn!(
|
||||
"Resource '{}' from server '{}' is overwriting an existing resource.",
|
||||
&r.uri,
|
||||
server_name
|
||||
);
|
||||
}
|
||||
self.resources
|
||||
.insert(r.uri.clone(), (server_name.to_string(), r));
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::debug!("No resources or failed to list on '{}': {}", server_name, e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a single MCP server with retry logic for remote transports
|
||||
async fn connect_server(config: &McpServerConfig) -> McpResult<RunningService<RoleClient, ()>> {
|
||||
let needs_retry = matches!(
|
||||
&config.transport,
|
||||
McpTransport::Sse { .. } | McpTransport::Streamable { .. }
|
||||
);
|
||||
if needs_retry {
|
||||
Self::connect_server_with_retry(config).await
|
||||
} else {
|
||||
Self::connect_server_impl(config).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect with exponential backoff retry for remote servers
|
||||
async fn connect_server_with_retry(
|
||||
config: &McpServerConfig,
|
||||
) -> McpResult<RunningService<RoleClient, ()>> {
|
||||
let backoff = ExponentialBackoffBuilder::new()
|
||||
.with_initial_interval(Duration::from_secs(1))
|
||||
.with_max_interval(Duration::from_secs(30))
|
||||
.with_max_elapsed_time(Some(Duration::from_secs(120)))
|
||||
.build();
|
||||
|
||||
backoff::future::retry(backoff, || async {
|
||||
match Self::connect_server_impl(config).await {
|
||||
Ok(client) => Ok(client),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to connect to '{}', retrying: {}", config.name, e);
|
||||
Err(backoff::Error::transient(e))
|
||||
}
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Internal implementation of server connection
|
||||
async fn connect_server_impl(
|
||||
config: &McpServerConfig,
|
||||
) -> McpResult<RunningService<RoleClient, ()>> {
|
||||
tracing::info!(
|
||||
"Connecting to MCP server '{}' via {:?}",
|
||||
config.name,
|
||||
config.transport
|
||||
);
|
||||
|
||||
match &config.transport {
|
||||
McpTransport::Stdio {
|
||||
command,
|
||||
args,
|
||||
envs,
|
||||
} => {
|
||||
let transport = TokioChildProcess::new(
|
||||
tokio::process::Command::new(command).configure(|cmd| {
|
||||
cmd.args(args)
|
||||
.envs(envs.iter())
|
||||
.stderr(std::process::Stdio::inherit());
|
||||
}),
|
||||
)
|
||||
.map_err(|e| McpError::Transport(format!("create stdio transport: {}", e)))?;
|
||||
|
||||
let client = ().serve(transport).await.map_err(|e| {
|
||||
McpError::ConnectionFailed(format!("initialize stdio client: {}", e))
|
||||
})?;
|
||||
|
||||
tracing::info!("Connected to stdio server '{}'", config.name);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
McpTransport::Sse { url, token } => {
|
||||
let transport = if let Some(tok) = token {
|
||||
let client = reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {}", tok).parse().map_err(|e| {
|
||||
McpError::Transport(format!("auth token: {}", e))
|
||||
})?,
|
||||
);
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.map_err(|e| McpError::Transport(format!("build HTTP client: {}", e)))?;
|
||||
|
||||
let cfg = SseClientConfig {
|
||||
sse_endpoint: url.clone().into(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
SseClientTransport::start_with_client(client, cfg)
|
||||
.await
|
||||
.map_err(|e| McpError::Transport(format!("create SSE transport: {}", e)))?
|
||||
} else {
|
||||
SseClientTransport::start(url.as_str())
|
||||
.await
|
||||
.map_err(|e| McpError::Transport(format!("create SSE transport: {}", e)))?
|
||||
};
|
||||
|
||||
let client = ().serve(transport).await.map_err(|e| {
|
||||
McpError::ConnectionFailed(format!("initialize SSE client: {}", e))
|
||||
})?;
|
||||
|
||||
tracing::info!("Connected to SSE server '{}' at {}", config.name, url);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
McpTransport::Streamable { url, token } => {
|
||||
let transport = if let Some(tok) = token {
|
||||
let mut cfg = StreamableHttpClientTransportConfig::with_uri(url.as_str());
|
||||
cfg.auth_header = Some(format!("Bearer {}", tok));
|
||||
StreamableHttpClientTransport::from_config(cfg)
|
||||
} else {
|
||||
StreamableHttpClientTransport::from_uri(url.as_str())
|
||||
};
|
||||
|
||||
let client = ().serve(transport).await.map_err(|e| {
|
||||
McpError::ConnectionFailed(format!("initialize streamable client: {}", e))
|
||||
})?;
|
||||
|
||||
tracing::info!(
|
||||
"Connected to streamable HTTP server '{}' at {}",
|
||||
config.name,
|
||||
url
|
||||
);
|
||||
Ok(client)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Helpers =====
|
||||
|
||||
fn client_for(&self, server_name: &str) -> McpResult<&RunningService<RoleClient, ()>> {
|
||||
self.clients
|
||||
.get(server_name)
|
||||
.ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))
|
||||
}
|
||||
|
||||
fn tool_entry(&self, name: &str) -> McpResult<(String, McpTool)> {
|
||||
self.tools
|
||||
.get(name)
|
||||
.map(|e| e.value().clone())
|
||||
.ok_or_else(|| McpError::ToolNotFound(name.to_string()))
|
||||
}
|
||||
|
||||
fn prompt_entry(&self, name: &str) -> McpResult<(String, Prompt)> {
|
||||
self.prompts
|
||||
.get(name)
|
||||
.map(|e| e.value().clone())
|
||||
.ok_or_else(|| McpError::PromptNotFound(name.to_string()))
|
||||
}
|
||||
|
||||
fn resource_entry(&self, uri: &str) -> McpResult<(String, Resource)> {
|
||||
self.resources
|
||||
.get(uri)
|
||||
.map(|e| e.value().clone())
|
||||
.ok_or_else(|| McpError::ResourceNotFound(uri.to_string()))
|
||||
}
|
||||
|
||||
// ===== Tool Methods =====
|
||||
|
||||
/// Call a tool by name
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Option<serde_json::Map<String, serde_json::Value>>,
|
||||
) -> McpResult<rmcp::model::CallToolResult> {
|
||||
let (server_name, _tool) = self.tool_entry(tool_name)?;
|
||||
let client = self.client_for(&server_name)?;
|
||||
|
||||
tracing::debug!("Calling tool '{}' on '{}'", tool_name, server_name);
|
||||
|
||||
client
|
||||
.peer()
|
||||
.call_tool(CallToolRequestParam {
|
||||
name: Cow::Owned(tool_name.to_string()),
|
||||
arguments,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::ToolExecution(format!("Tool call failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Get all available tools
|
||||
pub fn list_tools(&self) -> Vec<ToolInfo> {
|
||||
self.tools
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
let tool_name = entry.key().clone();
|
||||
let (server_name, tool) = entry.value();
|
||||
ToolInfo {
|
||||
name: tool_name,
|
||||
description: tool.description.as_deref().unwrap_or_default().to_string(),
|
||||
server: server_name.clone(),
|
||||
parameters: Some(serde_json::Value::Object((*tool.input_schema).clone())),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get a specific tool by name
|
||||
pub fn get_tool(&self, name: &str) -> Option<ToolInfo> {
|
||||
self.tools.get(name).map(|entry| {
|
||||
let (server_name, tool) = entry.value();
|
||||
ToolInfo {
|
||||
name: name.to_string(),
|
||||
description: tool.description.as_deref().unwrap_or_default().to_string(),
|
||||
server: server_name.clone(),
|
||||
parameters: Some(serde_json::Value::Object((*tool.input_schema).clone())),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if a tool exists
|
||||
pub fn has_tool(&self, name: &str) -> bool {
|
||||
self.tools.contains_key(name)
|
||||
}
|
||||
|
||||
/// Get list of connected servers
|
||||
pub fn list_servers(&self) -> Vec<String> {
|
||||
self.clients.keys().cloned().collect()
|
||||
}
|
||||
|
||||
// ===== Prompt Methods =====
|
||||
|
||||
/// Get a prompt by name with arguments
|
||||
pub async fn get_prompt(
|
||||
&self,
|
||||
prompt_name: &str,
|
||||
arguments: Option<serde_json::Map<String, serde_json::Value>>,
|
||||
) -> McpResult<GetPromptResult> {
|
||||
let (server_name, _prompt) = self.prompt_entry(prompt_name)?;
|
||||
let client = self.client_for(&server_name)?;
|
||||
|
||||
tracing::debug!("Getting prompt '{}' from '{}'", prompt_name, server_name);
|
||||
|
||||
client
|
||||
.peer()
|
||||
.get_prompt(GetPromptRequestParam {
|
||||
name: prompt_name.to_string(),
|
||||
arguments,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::ToolExecution(format!("Failed to get prompt: {}", e)))
|
||||
}
|
||||
|
||||
/// List all available prompts
|
||||
pub fn list_prompts(&self) -> Vec<PromptInfo> {
|
||||
self.prompts
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
let name = entry.key().clone();
|
||||
let (server_name, prompt) = entry.value();
|
||||
PromptInfo {
|
||||
name,
|
||||
description: prompt.description.clone(),
|
||||
server: server_name.clone(),
|
||||
arguments: prompt
|
||||
.arguments
|
||||
.clone()
|
||||
.map(|args| args.into_iter().map(|arg| serde_json::json!(arg)).collect()),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get a specific prompt info by name
|
||||
pub fn get_prompt_info(&self, name: &str) -> Option<PromptInfo> {
|
||||
self.prompts.get(name).map(|entry| {
|
||||
let (server_name, prompt) = entry.value();
|
||||
PromptInfo {
|
||||
name: name.to_string(),
|
||||
description: prompt.description.clone(),
|
||||
server: server_name.clone(),
|
||||
arguments: prompt
|
||||
.arguments
|
||||
.clone()
|
||||
.map(|args| args.into_iter().map(|arg| serde_json::json!(arg)).collect()),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ===== Resource Methods =====
|
||||
|
||||
/// Read a resource by URI
|
||||
pub async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult> {
|
||||
let (server_name, _resource) = self.resource_entry(uri)?;
|
||||
let client = self.client_for(&server_name)?;
|
||||
|
||||
tracing::debug!("Reading resource '{}' from '{}'", uri, server_name);
|
||||
|
||||
client
|
||||
.peer()
|
||||
.read_resource(ReadResourceRequestParam {
|
||||
uri: uri.to_string(),
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::ToolExecution(format!("Failed to read resource: {}", e)))
|
||||
}
|
||||
|
||||
/// List all available resources
|
||||
pub fn list_resources(&self) -> Vec<ResourceInfo> {
|
||||
self.resources
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
let uri = entry.key().clone();
|
||||
let (server_name, resource) = entry.value();
|
||||
ResourceInfo {
|
||||
uri,
|
||||
name: resource.name.clone(),
|
||||
description: resource.description.clone(),
|
||||
mime_type: resource.mime_type.clone(),
|
||||
server: server_name.clone(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get a specific resource info by URI
|
||||
pub fn get_resource_info(&self, uri: &str) -> Option<ResourceInfo> {
|
||||
self.resources.get(uri).map(|entry| {
|
||||
let (server_name, resource) = entry.value();
|
||||
ResourceInfo {
|
||||
uri: uri.to_string(),
|
||||
name: resource.name.clone(),
|
||||
description: resource.description.clone(),
|
||||
mime_type: resource.mime_type.clone(),
|
||||
server: server_name.clone(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Subscribe to resource changes
|
||||
pub async fn subscribe_resource(&self, uri: &str) -> McpResult<()> {
|
||||
let (server_name, _resource) = self.resource_entry(uri)?;
|
||||
let client = self.client_for(&server_name)?;
|
||||
|
||||
tracing::debug!("Subscribing to '{}' on '{}'", uri, server_name);
|
||||
|
||||
client
|
||||
.peer()
|
||||
.subscribe(rmcp::model::SubscribeRequestParam {
|
||||
uri: uri.to_string(),
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::ToolExecution(format!("Failed to subscribe: {}", e)))
|
||||
}
|
||||
|
||||
/// Unsubscribe from resource changes
|
||||
pub async fn unsubscribe_resource(&self, uri: &str) -> McpResult<()> {
|
||||
let (server_name, _resource) = self.resource_entry(uri)?;
|
||||
let client = self.client_for(&server_name)?;
|
||||
|
||||
tracing::debug!("Unsubscribing from '{}' on '{}'", uri, server_name);
|
||||
|
||||
client
|
||||
.peer()
|
||||
.unsubscribe(rmcp::model::UnsubscribeRequestParam {
|
||||
uri: uri.to_string(),
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::ToolExecution(format!("Failed to unsubscribe: {}", e)))
|
||||
}
|
||||
|
||||
/// Disconnect from all servers (for cleanup)
|
||||
pub async fn shutdown(&mut self) {
|
||||
for (name, client) in self.clients.drain() {
|
||||
if let Err(e) = client.cancel().await {
|
||||
tracing::warn!("Error disconnecting from '{}': {}", name, e);
|
||||
}
|
||||
}
|
||||
self.tools.clear();
|
||||
self.prompts.clear();
|
||||
self.resources.clear();
|
||||
}
|
||||
}
|
||||
52
sgl-router/src/mcp/config.rs
Normal file
52
sgl-router/src/mcp/config.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct McpConfig {
|
||||
pub servers: Vec<McpServerConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct McpServerConfig {
|
||||
pub name: String,
|
||||
#[serde(flatten)]
|
||||
pub transport: McpTransport,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "protocol", rename_all = "lowercase")]
|
||||
pub enum McpTransport {
|
||||
Stdio {
|
||||
command: String,
|
||||
#[serde(default)]
|
||||
args: Vec<String>,
|
||||
#[serde(default)]
|
||||
envs: HashMap<String, String>,
|
||||
},
|
||||
Sse {
|
||||
url: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
token: Option<String>,
|
||||
},
|
||||
Streamable {
|
||||
url: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
token: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl McpConfig {
|
||||
/// Load configuration from a YAML file
|
||||
pub async fn from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let content = tokio::fs::read_to_string(path).await?;
|
||||
let config: Self = serde_yaml::from_str(&content)?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Load configuration from environment variables (optional)
|
||||
pub fn from_env() -> Option<Self> {
|
||||
// This could be expanded to read from env vars
|
||||
// For now, return None to indicate env config not implemented
|
||||
None
|
||||
}
|
||||
}
|
||||
42
sgl-router/src/mcp/error.rs
Normal file
42
sgl-router/src/mcp/error.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use thiserror::Error;
|
||||
|
||||
pub type McpResult<T> = Result<T, McpError>;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum McpError {
|
||||
#[error("Server not found: {0}")]
|
||||
ServerNotFound(String),
|
||||
|
||||
#[error("Tool not found: {0}")]
|
||||
ToolNotFound(String),
|
||||
|
||||
#[error("Transport error: {0}")]
|
||||
Transport(String),
|
||||
|
||||
#[error("Tool execution failed: {0}")]
|
||||
ToolExecution(String),
|
||||
|
||||
#[error("Connection failed: {0}")]
|
||||
ConnectionFailed(String),
|
||||
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(String),
|
||||
|
||||
#[error("Authentication error: {0}")]
|
||||
Auth(String),
|
||||
|
||||
#[error("Resource not found: {0}")]
|
||||
ResourceNotFound(String),
|
||||
|
||||
#[error("Prompt not found: {0}")]
|
||||
PromptNotFound(String),
|
||||
|
||||
#[error(transparent)]
|
||||
Sdk(#[from] Box<rmcp::RmcpError>),
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Http(#[from] reqwest::Error),
|
||||
}
|
||||
@@ -1,9 +1,18 @@
|
||||
// mod.rs - MCP module exports
|
||||
pub mod tool_server;
|
||||
pub mod types;
|
||||
// MCP Client for SGLang Router
|
||||
//
|
||||
// This module provides a complete MCP (Model Context Protocol) client implementation
|
||||
// supporting multiple transport types (stdio, SSE, HTTP) and all MCP features:
|
||||
// - Tools: Discovery and execution
|
||||
// - Prompts: Reusable templates for LLM interactions
|
||||
// - Resources: File/data access with subscription support
|
||||
// - OAuth: Secure authentication for remote servers
|
||||
|
||||
pub use tool_server::{parse_sse_event, MCPToolServer, ToolStats};
|
||||
pub use types::{
|
||||
HttpConnection, MCPError, MCPResult, MultiToolSessionManager, SessionStats, ToolCall,
|
||||
ToolResult, ToolSession,
|
||||
};
|
||||
pub mod client_manager;
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod oauth;
|
||||
|
||||
// Re-export the main types for convenience
|
||||
pub use client_manager::{McpClientManager, PromptInfo, ResourceInfo, ToolInfo};
|
||||
pub use config::{McpConfig, McpServerConfig, McpTransport};
|
||||
pub use error::{McpError, McpResult};
|
||||
|
||||
191
sgl-router/src/mcp/oauth.rs
Normal file
191
sgl-router/src/mcp/oauth.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
// OAuth authentication support for MCP servers
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::Html,
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use serde::Deserialize;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
|
||||
use crate::mcp::error::{McpError, McpResult};
|
||||
|
||||
/// OAuth callback parameters
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CallbackParams {
|
||||
code: String,
|
||||
#[allow(dead_code)]
|
||||
state: Option<String>,
|
||||
}
|
||||
|
||||
/// State for the callback server
|
||||
#[derive(Clone)]
|
||||
struct CallbackState {
|
||||
code_receiver: Arc<Mutex<Option<oneshot::Sender<String>>>>,
|
||||
}
|
||||
|
||||
/// HTML page returned after successful OAuth callback
|
||||
const CALLBACK_HTML: &str = r#"
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>OAuth Success</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
}
|
||||
.container {
|
||||
background: white;
|
||||
padding: 40px;
|
||||
border-radius: 10px;
|
||||
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
|
||||
text-align: center;
|
||||
}
|
||||
h1 { color: #333; }
|
||||
p { color: #666; margin: 20px 0; }
|
||||
.success { color: #4CAF50; font-size: 48px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="success">✓</div>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<p>You can now close this window and return to your application.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"#;
|
||||
|
||||
/// OAuth authentication helper for MCP servers
|
||||
pub struct OAuthHelper {
|
||||
server_url: String,
|
||||
redirect_uri: String,
|
||||
callback_port: u16,
|
||||
}
|
||||
|
||||
impl OAuthHelper {
|
||||
/// Create a new OAuth helper
|
||||
pub fn new(server_url: String, redirect_uri: String, callback_port: u16) -> Self {
|
||||
Self {
|
||||
server_url,
|
||||
redirect_uri,
|
||||
callback_port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform OAuth authentication flow
|
||||
pub async fn authenticate(
|
||||
&self,
|
||||
scopes: &[&str],
|
||||
) -> McpResult<rmcp::transport::auth::AuthorizationManager> {
|
||||
// Initialize OAuth state machine
|
||||
let mut oauth_state = OAuthState::new(&self.server_url, None)
|
||||
.await
|
||||
.map_err(|e| McpError::Auth(format!("Failed to initialize OAuth: {}", e)))?;
|
||||
|
||||
oauth_state
|
||||
.start_authorization(scopes, &self.redirect_uri)
|
||||
.await
|
||||
.map_err(|e| McpError::Auth(format!("Failed to start authorization: {}", e)))?;
|
||||
|
||||
// Get authorization URL
|
||||
let auth_url = oauth_state
|
||||
.get_authorization_url()
|
||||
.await
|
||||
.map_err(|e| McpError::Auth(format!("Failed to get authorization URL: {}", e)))?;
|
||||
|
||||
tracing::info!("OAuth authorization URL: {}", auth_url);
|
||||
|
||||
// Start callback server and wait for code
|
||||
let auth_code = self.start_callback_server().await?;
|
||||
|
||||
// Exchange code for token
|
||||
oauth_state
|
||||
.handle_callback(&auth_code)
|
||||
.await
|
||||
.map_err(|e| McpError::Auth(format!("Failed to handle OAuth callback: {}", e)))?;
|
||||
|
||||
// Get authorization manager
|
||||
oauth_state
|
||||
.into_authorization_manager()
|
||||
.ok_or_else(|| McpError::Auth("Failed to get authorization manager".to_string()))
|
||||
}
|
||||
|
||||
/// Start a local HTTP server to receive the OAuth callback
|
||||
async fn start_callback_server(&self) -> McpResult<String> {
|
||||
let (code_sender, code_receiver) = oneshot::channel::<String>();
|
||||
|
||||
let state = CallbackState {
|
||||
code_receiver: Arc::new(Mutex::new(Some(code_sender))),
|
||||
};
|
||||
|
||||
// Create router for callback
|
||||
let app = Router::new()
|
||||
.route("/callback", get(Self::callback_handler))
|
||||
.with_state(state);
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], self.callback_port));
|
||||
|
||||
// Start server in background
|
||||
let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
|
||||
McpError::Auth(format!(
|
||||
"Failed to bind to callback port {}: {}",
|
||||
self.callback_port, e
|
||||
))
|
||||
})?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _ = axum::serve(listener, app).await;
|
||||
});
|
||||
|
||||
tracing::info!(
|
||||
"OAuth callback server started on port {}",
|
||||
self.callback_port
|
||||
);
|
||||
|
||||
// Wait for authorization code
|
||||
code_receiver
|
||||
.await
|
||||
.map_err(|_| McpError::Auth("Failed to receive authorization code".to_string()))
|
||||
}
|
||||
|
||||
/// Handle OAuth callback
|
||||
async fn callback_handler(
|
||||
Query(params): Query<CallbackParams>,
|
||||
State(state): State<CallbackState>,
|
||||
) -> Html<String> {
|
||||
tracing::debug!("Received OAuth callback with code");
|
||||
|
||||
// Send code to waiting task
|
||||
if let Some(sender) = state.code_receiver.lock().await.take() {
|
||||
let _ = sender.send(params.code);
|
||||
}
|
||||
|
||||
Html(CALLBACK_HTML.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an OAuth-authenticated client
|
||||
pub async fn create_oauth_client(
|
||||
server_url: String,
|
||||
_sse_url: String,
|
||||
redirect_uri: String,
|
||||
callback_port: u16,
|
||||
scopes: &[&str],
|
||||
) -> McpResult<rmcp::transport::auth::AuthClient<reqwest::Client>> {
|
||||
let helper = OAuthHelper::new(server_url, redirect_uri, callback_port);
|
||||
let auth_manager = helper.authenticate(scopes).await?;
|
||||
|
||||
let client = rmcp::transport::auth::AuthClient::new(reqwest::Client::default(), auth_manager);
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
@@ -1,534 +0,0 @@
|
||||
// tool_server.rs - Main MCP implementation (matching Python's tool_server.py)
|
||||
use crate::mcp::types::*;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Main MCP Tool Server
|
||||
pub struct MCPToolServer {
|
||||
/// Tool descriptions by server
|
||||
tool_descriptions: HashMap<String, Value>,
|
||||
/// Server URLs
|
||||
urls: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Default for MCPToolServer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl MCPToolServer {
|
||||
/// Create new MCPToolServer
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tool_descriptions: HashMap::new(),
|
||||
urls: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears all existing tool servers and adds new ones from the provided URL(s).
|
||||
/// URLs can be a single string or multiple comma-separated strings.
|
||||
pub async fn add_tool_server(&mut self, server_url: String) -> MCPResult<()> {
|
||||
let tool_urls: Vec<&str> = server_url.split(",").collect();
|
||||
let mut successful_connections = 0;
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Clear existing
|
||||
self.tool_descriptions = HashMap::new();
|
||||
self.urls = HashMap::new();
|
||||
|
||||
for url_str in tool_urls {
|
||||
let url_str = url_str.trim();
|
||||
|
||||
// Format URL for MCP-compliant connection
|
||||
let formatted_url = if url_str.starts_with("http://") || url_str.starts_with("https://")
|
||||
{
|
||||
url_str.to_string()
|
||||
} else {
|
||||
// Default to MCP endpoint if no protocol specified
|
||||
format!("http://{}", url_str)
|
||||
};
|
||||
|
||||
// Server connection with retry and error recovery
|
||||
match self.connect_to_server(&formatted_url).await {
|
||||
Ok((_init_response, tools_response)) => {
|
||||
// Process tools with validation
|
||||
let tools_obj = post_process_tools_description(tools_response);
|
||||
|
||||
// Tool storage with conflict detection
|
||||
for tool in &tools_obj.tools {
|
||||
let tool_name = &tool.name;
|
||||
|
||||
// Check for duplicate tools
|
||||
if self.tool_descriptions.contains_key(tool_name) {
|
||||
tracing::warn!(
|
||||
"Tool {} already exists. Ignoring duplicate tool from server {}",
|
||||
tool_name,
|
||||
formatted_url
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Store individual tool descriptions
|
||||
let tool_json = json!(tool);
|
||||
self.tool_descriptions
|
||||
.insert(tool_name.clone(), tool_json.clone());
|
||||
self.urls.insert(tool_name.clone(), formatted_url.clone());
|
||||
}
|
||||
|
||||
successful_connections += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
errors.push(format!("Failed to connect to {}: {}", formatted_url, e));
|
||||
tracing::warn!("Failed to connect to MCP server {}: {}", formatted_url, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Error handling - succeed if at least one server connects
|
||||
if successful_connections == 0 {
|
||||
let combined_error = errors.join("; ");
|
||||
return Err(MCPError::ConnectionError(format!(
|
||||
"Failed to connect to any MCP servers: {}",
|
||||
combined_error
|
||||
)));
|
||||
}
|
||||
|
||||
if !errors.is_empty() {
|
||||
tracing::warn!("Some MCP servers failed to connect: {}", errors.join("; "));
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"Successfully connected to {} MCP server(s), discovered {} tool(s)",
|
||||
successful_connections,
|
||||
self.tool_descriptions.len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Server connection with retries (internal helper)
|
||||
async fn connect_to_server(
|
||||
&self,
|
||||
url: &str,
|
||||
) -> MCPResult<(InitializeResponse, ListToolsResponse)> {
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
const RETRY_DELAY_MS: u64 = 1000;
|
||||
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 1..=MAX_RETRIES {
|
||||
match list_server_and_tools(url).await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(e) => {
|
||||
last_error = Some(e);
|
||||
if attempt < MAX_RETRIES {
|
||||
tracing::debug!(
|
||||
"MCP server connection attempt {}/{} failed for {}: {}. Retrying...",
|
||||
attempt,
|
||||
MAX_RETRIES,
|
||||
url,
|
||||
last_error.as_ref().unwrap()
|
||||
);
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(
|
||||
RETRY_DELAY_MS * attempt as u64,
|
||||
))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap())
|
||||
}
|
||||
|
||||
/// Check if tool exists (matching Python's has_tool)
|
||||
pub fn has_tool(&self, tool_name: &str) -> bool {
|
||||
self.tool_descriptions.contains_key(tool_name)
|
||||
}
|
||||
|
||||
/// Get tool description (matching Python's get_tool_description)
|
||||
pub fn get_tool_description(&self, tool_name: &str) -> Option<&Value> {
|
||||
self.tool_descriptions.get(tool_name)
|
||||
}
|
||||
|
||||
/// Get tool session (matching Python's get_tool_session)
|
||||
pub async fn get_tool_session(&self, tool_name: &str) -> MCPResult<ToolSession> {
|
||||
let url = self
|
||||
.urls
|
||||
.get(tool_name)
|
||||
.ok_or_else(|| MCPError::ToolNotFound(tool_name.to_string()))?;
|
||||
|
||||
// Create session
|
||||
ToolSession::new(url.clone()).await
|
||||
}
|
||||
|
||||
/// Create multi-tool session manager
|
||||
pub async fn create_multi_tool_session(
|
||||
&self,
|
||||
tool_names: Vec<String>,
|
||||
) -> MCPResult<MultiToolSessionManager> {
|
||||
let mut session_manager = MultiToolSessionManager::new();
|
||||
|
||||
// Group tools by server URL for efficient session creation
|
||||
let mut server_tools: std::collections::HashMap<String, Vec<String>> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
for tool_name in tool_names {
|
||||
if let Some(url) = self.urls.get(&tool_name) {
|
||||
server_tools.entry(url.clone()).or_default().push(tool_name);
|
||||
} else {
|
||||
return Err(MCPError::ToolNotFound(format!(
|
||||
"Tool not found: {}",
|
||||
tool_name
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Create sessions for each server
|
||||
for (server_url, tools) in server_tools {
|
||||
session_manager
|
||||
.add_tools_from_server(server_url, tools)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(session_manager)
|
||||
}
|
||||
|
||||
/// List all available tools
|
||||
pub fn list_tools(&self) -> Vec<String> {
|
||||
self.tool_descriptions.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get tool statistics
|
||||
pub fn get_tool_stats(&self) -> ToolStats {
|
||||
ToolStats {
|
||||
total_tools: self.tool_descriptions.len(),
|
||||
total_servers: self
|
||||
.urls
|
||||
.values()
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// List all connected servers
|
||||
pub fn list_servers(&self) -> Vec<String> {
|
||||
self.urls
|
||||
.values()
|
||||
.cloned()
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Check if a specific server is connected
|
||||
pub fn has_server(&self, server_url: &str) -> bool {
|
||||
self.urls.values().any(|url| url == server_url)
|
||||
}
|
||||
|
||||
/// Execute a tool directly (convenience method for simple usage)
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> MCPResult<serde_json::Value> {
|
||||
let session = self.get_tool_session(tool_name).await?;
|
||||
session.call_tool(tool_name, arguments).await
|
||||
}
|
||||
|
||||
/// Create a tool session from server URL (convenience method)
|
||||
pub async fn create_session_from_url(&self, server_url: &str) -> MCPResult<ToolSession> {
|
||||
ToolSession::new(server_url.to_string()).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool statistics for monitoring
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolStats {
|
||||
pub total_tools: usize,
|
||||
pub total_servers: usize,
|
||||
}
|
||||
|
||||
/// MCP-compliant server connection using JSON-RPC over SSE
|
||||
async fn list_server_and_tools(
|
||||
server_url: &str,
|
||||
) -> MCPResult<(InitializeResponse, ListToolsResponse)> {
|
||||
// MCP specification:
|
||||
// 1. Connect to MCP endpoint with GET (SSE) or POST (JSON-RPC)
|
||||
// 2. Send initialize request
|
||||
// 3. Send tools/list request
|
||||
// 4. Parse JSON-RPC responses
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Step 1: Send initialize request
|
||||
let init_request = MCPRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: "1".to_string(),
|
||||
method: "initialize".to_string(),
|
||||
params: Some(json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {}
|
||||
})),
|
||||
};
|
||||
|
||||
let init_response = send_mcp_request(&client, server_url, init_request).await?;
|
||||
let init_result: InitializeResponse = serde_json::from_value(init_response).map_err(|e| {
|
||||
MCPError::SerializationError(format!("Failed to parse initialize response: {}", e))
|
||||
})?;
|
||||
|
||||
// Step 2: Send tools/list request
|
||||
let tools_request = MCPRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: "2".to_string(),
|
||||
method: "tools/list".to_string(),
|
||||
params: Some(json!({})),
|
||||
};
|
||||
|
||||
let tools_response = send_mcp_request(&client, server_url, tools_request).await?;
|
||||
let tools_result: ListToolsResponse = serde_json::from_value(tools_response).map_err(|e| {
|
||||
MCPError::SerializationError(format!("Failed to parse tools/list response: {}", e))
|
||||
})?;
|
||||
|
||||
Ok((init_result, tools_result))
|
||||
}
|
||||
|
||||
/// Send MCP JSON-RPC request (supports both HTTP POST and SSE)
|
||||
async fn send_mcp_request(
|
||||
client: &reqwest::Client,
|
||||
url: &str,
|
||||
request: MCPRequest,
|
||||
) -> MCPResult<Value> {
|
||||
// Use HTTP POST for JSON-RPC requests
|
||||
let response = client
|
||||
.post(url)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| MCPError::ConnectionError(format!("MCP request failed: {}", e)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(MCPError::ProtocolError(format!(
|
||||
"HTTP {}",
|
||||
response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let mcp_response: MCPResponse = response.json().await.map_err(|e| {
|
||||
MCPError::SerializationError(format!("Failed to parse MCP response: {}", e))
|
||||
})?;
|
||||
|
||||
if let Some(error) = mcp_response.error {
|
||||
return Err(MCPError::ProtocolError(format!(
|
||||
"MCP error: {}",
|
||||
error.message
|
||||
)));
|
||||
}
|
||||
|
||||
mcp_response
|
||||
.result
|
||||
.ok_or_else(|| MCPError::ProtocolError("No result in MCP response".to_string()))
|
||||
}
|
||||
|
||||
// Removed old send_http_request - now using send_mcp_request with proper MCP protocol
|
||||
|
||||
/// Parse SSE event format (MCP-compliant JSON-RPC only)
|
||||
pub fn parse_sse_event(event: &str) -> MCPResult<Option<Value>> {
|
||||
let mut data_lines = Vec::new();
|
||||
|
||||
for line in event.lines() {
|
||||
if let Some(stripped) = line.strip_prefix("data: ") {
|
||||
data_lines.push(stripped);
|
||||
}
|
||||
}
|
||||
|
||||
if data_lines.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let json_data = data_lines.join("\n");
|
||||
if json_data.trim().is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Parse as MCP JSON-RPC response only (no custom events)
|
||||
let mcp_response: MCPResponse = serde_json::from_str(&json_data).map_err(|e| {
|
||||
MCPError::SerializationError(format!(
|
||||
"Failed to parse JSON-RPC response: {} - Data: {}",
|
||||
e, json_data
|
||||
))
|
||||
})?;
|
||||
|
||||
if let Some(error) = mcp_response.error {
|
||||
return Err(MCPError::ProtocolError(error.message));
|
||||
}
|
||||
|
||||
Ok(mcp_response.result)
|
||||
}
|
||||
|
||||
/// Schema adaptation matching Python's trim_schema()
|
||||
fn trim_schema(schema: &mut Value) {
|
||||
if let Some(obj) = schema.as_object_mut() {
|
||||
// Remove title and null defaults
|
||||
obj.remove("title");
|
||||
if obj.get("default") == Some(&Value::Null) {
|
||||
obj.remove("default");
|
||||
}
|
||||
|
||||
// Convert anyOf to type arrays
|
||||
if let Some(any_of) = obj.remove("anyOf") {
|
||||
if let Some(array) = any_of.as_array() {
|
||||
let types: Vec<String> = array
|
||||
.iter()
|
||||
.filter_map(|item| {
|
||||
item.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|t| *t != "null")
|
||||
.map(|t| t.to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Handle single type vs array of types
|
||||
match types.len() {
|
||||
0 => {} // No valid types found
|
||||
1 => {
|
||||
obj.insert("type".to_string(), json!(types[0]));
|
||||
}
|
||||
_ => {
|
||||
obj.insert("type".to_string(), json!(types));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle oneOf similar to anyOf
|
||||
if let Some(one_of) = obj.remove("oneOf") {
|
||||
if let Some(array) = one_of.as_array() {
|
||||
let types: Vec<String> = array
|
||||
.iter()
|
||||
.filter_map(|item| {
|
||||
item.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|t| *t != "null")
|
||||
.map(|t| t.to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !types.is_empty() {
|
||||
obj.insert("type".to_string(), json!(types));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recursive processing for properties
|
||||
if let Some(properties) = obj.get_mut("properties") {
|
||||
if let Some(props_obj) = properties.as_object_mut() {
|
||||
for (_, value) in props_obj.iter_mut() {
|
||||
trim_schema(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle nested schemas in items (for arrays)
|
||||
if let Some(items) = obj.get_mut("items") {
|
||||
trim_schema(items);
|
||||
}
|
||||
|
||||
// Handle nested schemas in additionalProperties
|
||||
if let Some(additional_props) = obj.get_mut("additionalProperties") {
|
||||
if additional_props.is_object() {
|
||||
trim_schema(additional_props);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle patternProperties (for dynamic property names)
|
||||
if let Some(pattern_props) = obj.get_mut("patternProperties") {
|
||||
if let Some(pattern_obj) = pattern_props.as_object_mut() {
|
||||
for (_, value) in pattern_obj.iter_mut() {
|
||||
trim_schema(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle allOf in nested contexts
|
||||
if let Some(all_of) = obj.get_mut("allOf") {
|
||||
if let Some(array) = all_of.as_array_mut() {
|
||||
for item in array.iter_mut() {
|
||||
trim_schema(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool processing with filtering
|
||||
fn post_process_tools_description(mut tools_response: ListToolsResponse) -> ListToolsResponse {
|
||||
// Adapt schemas for Harmony
|
||||
for tool in &mut tools_response.tools {
|
||||
trim_schema(&mut tool.input_schema);
|
||||
}
|
||||
|
||||
// Tool filtering based on annotations
|
||||
let initial_count = tools_response.tools.len();
|
||||
|
||||
tools_response.tools.retain(|tool| {
|
||||
// Check include_in_prompt annotation (Python behavior)
|
||||
let include_in_prompt = tool
|
||||
.annotations
|
||||
.as_ref()
|
||||
.and_then(|a| a.get("include_in_prompt"))
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(true);
|
||||
|
||||
if !include_in_prompt {
|
||||
tracing::debug!(
|
||||
"Filtering out tool '{}' due to include_in_prompt=false",
|
||||
tool.name
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if tool is explicitly disabled
|
||||
let disabled = tool
|
||||
.annotations
|
||||
.as_ref()
|
||||
.and_then(|a| a.get("disabled"))
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
if disabled {
|
||||
tracing::debug!("Filtering out disabled tool '{}'", tool.name);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate tool has required fields
|
||||
if tool.name.trim().is_empty() {
|
||||
tracing::warn!("Filtering out tool with empty name");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check for valid input schema
|
||||
if tool.input_schema.is_null() {
|
||||
tracing::warn!("Tool '{}' has null input schema, but keeping it", tool.name);
|
||||
}
|
||||
|
||||
true
|
||||
});
|
||||
|
||||
let filtered_count = tools_response.tools.len();
|
||||
if filtered_count != initial_count {
|
||||
tracing::info!(
|
||||
"Filtered tools: {} -> {} ({} removed)",
|
||||
initial_count,
|
||||
filtered_count,
|
||||
initial_count - filtered_count
|
||||
);
|
||||
}
|
||||
|
||||
tools_response
|
||||
}
|
||||
|
||||
// Tests moved to tests/mcp_comprehensive_test.rs for better organization
|
||||
@@ -1,345 +0,0 @@
|
||||
// types.rs - All MCP data structures
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use uuid;
|
||||
|
||||
// ===== Errors =====
|
||||
#[derive(Error, Debug)]
|
||||
pub enum MCPError {
|
||||
#[error("Connection failed: {0}")]
|
||||
ConnectionError(String),
|
||||
#[error("Invalid URL: {0}")]
|
||||
InvalidURL(String),
|
||||
#[error("Protocol error: {0}")]
|
||||
ProtocolError(String),
|
||||
#[error("Tool execution failed: {0}")]
|
||||
ToolExecutionError(String),
|
||||
#[error("Tool not found: {0}")]
|
||||
ToolNotFound(String),
|
||||
#[error("Serialization error: {0}")]
|
||||
SerializationError(String),
|
||||
#[error("Configuration error: {0}")]
|
||||
ConfigurationError(String),
|
||||
}
|
||||
|
||||
pub type MCPResult<T> = Result<T, MCPError>;
|
||||
|
||||
// Add From implementations for common error types
|
||||
impl From<serde_json::Error> for MCPError {
|
||||
fn from(err: serde_json::Error) -> Self {
|
||||
MCPError::SerializationError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for MCPError {
|
||||
fn from(err: reqwest::Error) -> Self {
|
||||
MCPError::ConnectionError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// ===== MCP Protocol Types =====
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MCPRequest {
|
||||
pub jsonrpc: String,
|
||||
pub id: String,
|
||||
pub method: String,
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MCPResponse {
|
||||
pub jsonrpc: String,
|
||||
pub id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<MCPErrorResponse>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MCPErrorResponse {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
// ===== MCP Server Response Types =====
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct InitializeResponse {
|
||||
#[serde(rename = "serverInfo")]
|
||||
pub server_info: ServerInfo,
|
||||
pub instructions: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ServerInfo {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ListToolsResponse {
|
||||
pub tools: Vec<ToolInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolInfo {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "inputSchema")]
|
||||
pub input_schema: serde_json::Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub annotations: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
// ===== Types =====
|
||||
pub type ToolCall = serde_json::Value; // Python uses dict
|
||||
pub type ToolResult = serde_json::Value; // Python uses dict
|
||||
|
||||
// ===== Connection Types =====
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HttpConnection {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
// ===== Tool Session =====
|
||||
pub struct ToolSession {
|
||||
pub connection: HttpConnection,
|
||||
pub client: reqwest::Client,
|
||||
pub session_initialized: bool,
|
||||
}
|
||||
|
||||
impl ToolSession {
|
||||
pub async fn new(connection_str: String) -> MCPResult<Self> {
|
||||
if !connection_str.starts_with("http://") && !connection_str.starts_with("https://") {
|
||||
return Err(MCPError::InvalidURL(format!(
|
||||
"Only HTTP/HTTPS URLs are supported: {}",
|
||||
connection_str
|
||||
)));
|
||||
}
|
||||
|
||||
let mut session = Self {
|
||||
connection: HttpConnection {
|
||||
url: connection_str,
|
||||
},
|
||||
client: reqwest::Client::new(),
|
||||
session_initialized: false,
|
||||
};
|
||||
|
||||
// Initialize the session
|
||||
session.initialize().await?;
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub async fn new_http(url: String) -> MCPResult<Self> {
|
||||
Self::new(url).await
|
||||
}
|
||||
|
||||
/// Initialize the session
|
||||
pub async fn initialize(&mut self) -> MCPResult<()> {
|
||||
if self.session_initialized {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let init_request = MCPRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: "init".to_string(),
|
||||
method: "initialize".to_string(),
|
||||
params: Some(serde_json::json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {}
|
||||
})),
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.connection.url)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&init_request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| MCPError::ConnectionError(format!("Initialize failed: {}", e)))?;
|
||||
|
||||
let mcp_response: MCPResponse = response.json().await.map_err(|e| {
|
||||
MCPError::SerializationError(format!("Failed to parse initialize response: {}", e))
|
||||
})?;
|
||||
|
||||
if let Some(error) = mcp_response.error {
|
||||
return Err(MCPError::ProtocolError(format!(
|
||||
"Initialize error: {}",
|
||||
error.message
|
||||
)));
|
||||
}
|
||||
|
||||
self.session_initialized = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Call a tool using MCP tools/call
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> MCPResult<serde_json::Value> {
|
||||
if !self.session_initialized {
|
||||
return Err(MCPError::ProtocolError(
|
||||
"Session not initialized. Call initialize() first.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
let request = MCPRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: format!("call_{}", uuid::Uuid::new_v4()),
|
||||
method: "tools/call".to_string(),
|
||||
params: Some(json!({
|
||||
"name": name,
|
||||
"arguments": arguments
|
||||
})),
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.connection.url)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| MCPError::ConnectionError(format!("Tool call failed: {}", e)))?;
|
||||
|
||||
let mcp_response: MCPResponse = response.json().await.map_err(|e| {
|
||||
MCPError::SerializationError(format!("Failed to parse tool response: {}", e))
|
||||
})?;
|
||||
|
||||
if let Some(error) = mcp_response.error {
|
||||
return Err(MCPError::ToolExecutionError(format!(
|
||||
"Tool '{}' failed: {}",
|
||||
name, error.message
|
||||
)));
|
||||
}
|
||||
|
||||
mcp_response
|
||||
.result
|
||||
.ok_or_else(|| MCPError::ProtocolError("No result in tool response".to_string()))
|
||||
}
|
||||
|
||||
/// Check if session is ready for tool calls
|
||||
pub fn is_ready(&self) -> bool {
|
||||
self.session_initialized
|
||||
}
|
||||
|
||||
/// Get connection info
|
||||
pub fn connection_info(&self) -> String {
|
||||
format!("HTTP: {}", self.connection.url)
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Multi-Tool Session Manager =====
|
||||
pub struct MultiToolSessionManager {
|
||||
sessions: HashMap<String, ToolSession>, // server_url -> session
|
||||
tool_to_server: HashMap<String, String>, // tool_name -> server_url mapping
|
||||
}
|
||||
|
||||
impl Default for MultiToolSessionManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl MultiToolSessionManager {
|
||||
/// Create new multi-tool session manager
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
tool_to_server: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add tools from an MCP server (optimized to share sessions per server)
|
||||
pub async fn add_tools_from_server(
|
||||
&mut self,
|
||||
server_url: String,
|
||||
tool_names: Vec<String>,
|
||||
) -> MCPResult<()> {
|
||||
// Create one session per server URL (if not already exists)
|
||||
if !self.sessions.contains_key(&server_url) {
|
||||
let session = ToolSession::new(server_url.clone()).await?;
|
||||
self.sessions.insert(server_url.clone(), session);
|
||||
}
|
||||
|
||||
// Map all tools to this server URL
|
||||
for tool_name in tool_names {
|
||||
self.tool_to_server.insert(tool_name, server_url.clone());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get session for a specific tool
|
||||
pub fn get_session(&self, tool_name: &str) -> Option<&ToolSession> {
|
||||
let server_url = self.tool_to_server.get(tool_name)?;
|
||||
self.sessions.get(server_url)
|
||||
}
|
||||
|
||||
/// Execute tool with automatic session management
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> MCPResult<serde_json::Value> {
|
||||
let server_url = self
|
||||
.tool_to_server
|
||||
.get(tool_name)
|
||||
.ok_or_else(|| MCPError::ToolNotFound(format!("No mapping for tool: {}", tool_name)))?;
|
||||
|
||||
let session = self.sessions.get(server_url).ok_or_else(|| {
|
||||
MCPError::ToolNotFound(format!("No session for server: {}", server_url))
|
||||
})?;
|
||||
|
||||
session.call_tool(tool_name, arguments).await
|
||||
}
|
||||
|
||||
/// Execute multiple tools concurrently
|
||||
pub async fn call_tools_concurrent(
|
||||
&self,
|
||||
tool_calls: Vec<(String, serde_json::Value)>,
|
||||
) -> Vec<MCPResult<serde_json::Value>> {
|
||||
let futures: Vec<_> = tool_calls
|
||||
.into_iter()
|
||||
.map(|(tool_name, args)| async move { self.call_tool(&tool_name, args).await })
|
||||
.collect();
|
||||
|
||||
futures::future::join_all(futures).await
|
||||
}
|
||||
|
||||
/// Get all available tool names
|
||||
pub fn list_tools(&self) -> Vec<String> {
|
||||
self.tool_to_server.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Check if tool is available
|
||||
pub fn has_tool(&self, tool_name: &str) -> bool {
|
||||
self.tool_to_server.contains_key(tool_name)
|
||||
}
|
||||
|
||||
/// Get session statistics
|
||||
pub fn session_stats(&self) -> SessionStats {
|
||||
let total_sessions = self.sessions.len();
|
||||
let ready_sessions = self.sessions.values().filter(|s| s.is_ready()).count();
|
||||
let unique_servers = self.sessions.len(); // Now sessions = servers
|
||||
|
||||
SessionStats {
|
||||
total_sessions,
|
||||
ready_sessions,
|
||||
unique_servers,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionStats {
|
||||
pub total_sessions: usize,
|
||||
pub ready_sessions: usize,
|
||||
pub unique_servers: usize,
|
||||
}
|
||||
@@ -1,9 +1,14 @@
|
||||
// tests/common/mock_mcp_server.rs - Mock MCP server for testing
|
||||
|
||||
use axum::{
|
||||
extract::Json, http::StatusCode, response::Json as ResponseJson, routing::post, Router,
|
||||
use rmcp::{
|
||||
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
|
||||
model::*,
|
||||
service::RequestContext,
|
||||
tool, tool_handler, tool_router,
|
||||
transport::streamable_http_server::{
|
||||
session::local::LocalSessionManager, StreamableHttpService,
|
||||
},
|
||||
ErrorData as McpError, RoleServer, ServerHandler,
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Mock MCP server that returns hardcoded responses for testing
|
||||
@@ -12,6 +17,69 @@ pub struct MockMCPServer {
|
||||
pub server_handle: Option<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
/// Simple test server with mock search tools
|
||||
#[derive(Clone)]
|
||||
pub struct MockSearchServer {
|
||||
tool_router: ToolRouter<MockSearchServer>,
|
||||
}
|
||||
|
||||
#[tool_router]
|
||||
impl MockSearchServer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tool_router: Self::tool_router(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(description = "Mock web search tool")]
|
||||
fn brave_web_search(
|
||||
&self,
|
||||
Parameters(params): Parameters<serde_json::Map<String, serde_json::Value>>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let query = params
|
||||
.get("query")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("test");
|
||||
Ok(CallToolResult::success(vec![Content::text(format!(
|
||||
"Mock search results for: {}",
|
||||
query
|
||||
))]))
|
||||
}
|
||||
|
||||
#[tool(description = "Mock local search tool")]
|
||||
fn brave_local_search(
|
||||
&self,
|
||||
Parameters(_params): Parameters<serde_json::Map<String, serde_json::Value>>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
Ok(CallToolResult::success(vec![Content::text(
|
||||
"Mock local search results",
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool_handler]
|
||||
impl ServerHandler for MockSearchServer {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
protocol_version: ProtocolVersion::V_2024_11_05,
|
||||
capabilities: ServerCapabilities::builder().enable_tools().build(),
|
||||
server_info: Implementation {
|
||||
name: "Mock MCP Server".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
},
|
||||
instructions: Some("Mock server for testing".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn initialize(
|
||||
&self,
|
||||
_request: InitializeRequestParam,
|
||||
_context: RequestContext<RoleServer>,
|
||||
) -> Result<InitializeResult, McpError> {
|
||||
Ok(self.get_info())
|
||||
}
|
||||
}
|
||||
|
||||
impl MockMCPServer {
|
||||
/// Start a mock MCP server on an available port
|
||||
pub async fn start() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
@@ -19,7 +87,14 @@ impl MockMCPServer {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await?;
|
||||
let port = listener.local_addr()?.port();
|
||||
|
||||
let app = Router::new().route("/mcp", post(handle_mcp_request));
|
||||
// Create the MCP service using rmcp's StreamableHttpService
|
||||
let service = StreamableHttpService::new(
|
||||
|| Ok(MockSearchServer::new()),
|
||||
LocalSessionManager::default().into(),
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let app = axum::Router::new().nest_service("/mcp", service);
|
||||
|
||||
let server_handle = tokio::spawn(async move {
|
||||
axum::serve(listener, app)
|
||||
@@ -59,142 +134,10 @@ impl Drop for MockMCPServer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle MCP requests and return mock responses
|
||||
async fn handle_mcp_request(Json(request): Json<Value>) -> Result<ResponseJson<Value>, StatusCode> {
|
||||
// Parse the JSON-RPC request
|
||||
let method = request.get("method").and_then(|m| m.as_str()).unwrap_or("");
|
||||
|
||||
let id = request
|
||||
.get("id")
|
||||
.and_then(|i| i.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let response = match method {
|
||||
"initialize" => {
|
||||
// Mock initialize response
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"serverInfo": {
|
||||
"name": "Mock MCP Server",
|
||||
"version": "1.0.0"
|
||||
},
|
||||
"instructions": "Mock server for testing"
|
||||
}
|
||||
})
|
||||
}
|
||||
"tools/list" => {
|
||||
// Mock tools list response
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "brave_web_search",
|
||||
"description": "Mock web search tool",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"count": {"type": "integer"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "brave_local_search",
|
||||
"description": "Mock local search tool",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
}
|
||||
"tools/call" => {
|
||||
// Mock tool call response
|
||||
let empty_json = json!({});
|
||||
let params = request.get("params").unwrap_or(&empty_json);
|
||||
let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
|
||||
let empty_args = json!({});
|
||||
let arguments = params.get("arguments").unwrap_or(&empty_args);
|
||||
|
||||
match tool_name {
|
||||
"brave_web_search" => {
|
||||
let query = arguments
|
||||
.get("query")
|
||||
.and_then(|q| q.as_str())
|
||||
.unwrap_or("test");
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": format!("Mock search results for: {}", query)
|
||||
}
|
||||
],
|
||||
"isError": false
|
||||
}
|
||||
})
|
||||
}
|
||||
"brave_local_search" => {
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Mock local search results"
|
||||
}
|
||||
],
|
||||
"isError": false
|
||||
}
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
// Unknown tool
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"error": {
|
||||
"code": -1,
|
||||
"message": format!("Unknown tool: {}", tool_name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Unknown method
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": format!("Method not found: {}", method)
|
||||
}
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ResponseJson(response))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(unused_imports)]
|
||||
mod tests {
|
||||
#[allow(unused_imports)]
|
||||
use super::MockMCPServer;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_server_startup() {
|
||||
@@ -205,32 +148,32 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_server_responses() {
|
||||
async fn test_mock_server_with_rmcp_client() {
|
||||
let mut server = MockMCPServer::start().await.unwrap();
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Test initialize
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": "1",
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {}
|
||||
// Test that we can connect with rmcp client
|
||||
use rmcp::transport::StreamableHttpClientTransport;
|
||||
use rmcp::ServiceExt;
|
||||
|
||||
let transport = StreamableHttpClientTransport::from_uri(server.url().as_str());
|
||||
let client = ().serve(transport).await;
|
||||
|
||||
assert!(client.is_ok(), "Should be able to connect to mock server");
|
||||
|
||||
if let Ok(client) = client {
|
||||
// Test listing tools
|
||||
let tools = client.peer().list_all_tools().await;
|
||||
assert!(tools.is_ok(), "Should be able to list tools");
|
||||
|
||||
if let Ok(tools) = tools {
|
||||
assert_eq!(tools.len(), 2, "Should have 2 tools");
|
||||
assert!(tools.iter().any(|t| t.name == "brave_web_search"));
|
||||
assert!(tools.iter().any(|t| t.name == "brave_local_search"));
|
||||
}
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(server.url())
|
||||
.json(&init_request)
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(response.status().is_success());
|
||||
let json: Value = response.json().await.unwrap();
|
||||
assert_eq!(json["jsonrpc"], "2.0");
|
||||
assert_eq!(json["result"]["serverInfo"]["name"], "Mock MCP Server");
|
||||
// Shutdown by dropping the client
|
||||
drop(client);
|
||||
}
|
||||
|
||||
server.stop().await;
|
||||
}
|
||||
|
||||
@@ -2,18 +2,19 @@
|
||||
// functionality required for SGLang responses API integration.
|
||||
//
|
||||
// Test Coverage:
|
||||
// - Core MCP server functionality (Python tool_server.py parity)
|
||||
// - Core MCP server functionality
|
||||
// - Tool session management (individual and multi-tool)
|
||||
// - Tool execution and error handling
|
||||
// - Schema adaptation and validation
|
||||
// - SSE parsing and protocol compliance
|
||||
// - Mock server integration for reliable testing
|
||||
|
||||
mod common;
|
||||
|
||||
use common::mock_mcp_server::MockMCPServer;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::mcp::{parse_sse_event, MCPToolServer, MultiToolSessionManager, ToolSession};
|
||||
use sglang_router_rs::mcp::{McpClientManager, McpConfig, McpError, McpServerConfig, McpTransport};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Create a new mock server for testing (each test gets its own)
|
||||
async fn create_mock_server() -> MockMCPServer {
|
||||
MockMCPServer::start()
|
||||
@@ -21,49 +22,69 @@ async fn create_mock_server() -> MockMCPServer {
|
||||
.expect("Failed to start mock MCP server")
|
||||
}
|
||||
|
||||
// Core MCP Server Tests (Python parity)
|
||||
// Core MCP Server Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_server_initialization() {
|
||||
let server = MCPToolServer::new();
|
||||
// Test that we can create an empty configuration
|
||||
let config = McpConfig { servers: vec![] };
|
||||
|
||||
assert!(!server.has_tool("any_tool"));
|
||||
assert_eq!(server.list_tools().len(), 0);
|
||||
assert_eq!(server.list_servers().len(), 0);
|
||||
|
||||
let stats = server.get_tool_stats();
|
||||
assert_eq!(stats.total_tools, 0);
|
||||
assert_eq!(stats.total_servers, 0);
|
||||
// Should fail with no servers
|
||||
let result = McpClientManager::new(config).await;
|
||||
assert!(result.is_err(), "Should fail with no servers configured");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_server_connection_with_mock() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
let result = mcp_server.add_tool_server(mock_server.url()).await;
|
||||
let config = McpConfig {
|
||||
servers: vec![McpServerConfig {
|
||||
name: "mock_server".to_string(),
|
||||
transport: McpTransport::Streamable {
|
||||
url: mock_server.url(),
|
||||
token: None,
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let result = McpClientManager::new(config).await;
|
||||
assert!(result.is_ok(), "Should connect to mock server");
|
||||
|
||||
let stats = mcp_server.get_tool_stats();
|
||||
assert_eq!(stats.total_tools, 2);
|
||||
assert_eq!(stats.total_servers, 1);
|
||||
let mut manager = result.unwrap();
|
||||
|
||||
assert!(mcp_server.has_tool("brave_web_search"));
|
||||
assert!(mcp_server.has_tool("brave_local_search"));
|
||||
let servers = manager.list_servers();
|
||||
assert_eq!(servers.len(), 1);
|
||||
assert!(servers.contains(&"mock_server".to_string()));
|
||||
|
||||
let tools = manager.list_tools();
|
||||
assert_eq!(tools.len(), 2, "Should have 2 tools from mock server");
|
||||
|
||||
assert!(manager.has_tool("brave_web_search"));
|
||||
assert!(manager.has_tool("brave_local_search"));
|
||||
|
||||
manager.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_availability_checking() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
assert!(!mcp_server.has_tool("brave_web_search"));
|
||||
let config = McpConfig {
|
||||
servers: vec![McpServerConfig {
|
||||
name: "mock_server".to_string(),
|
||||
transport: McpTransport::Streamable {
|
||||
url: mock_server.url(),
|
||||
token: None,
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
let mut manager = McpClientManager::new(config).await.unwrap();
|
||||
|
||||
let test_tools = vec!["brave_web_search", "brave_local_search", "calculator"];
|
||||
for tool in test_tools {
|
||||
let available = mcp_server.has_tool(tool);
|
||||
let available = manager.has_tool(tool);
|
||||
match tool {
|
||||
"brave_web_search" | "brave_local_search" => {
|
||||
assert!(
|
||||
@@ -82,90 +103,77 @@ async fn test_tool_availability_checking() {
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
manager.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_server_url_parsing() {
|
||||
async fn test_multi_server_connection() {
|
||||
let mock_server1 = create_mock_server().await;
|
||||
let mock_server2 = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
let combined_urls = format!("{},{}", mock_server1.url(), mock_server2.url());
|
||||
let result = mcp_server.add_tool_server(combined_urls).await;
|
||||
assert!(result.is_ok(), "Should connect to multiple servers");
|
||||
let config = McpConfig {
|
||||
servers: vec![
|
||||
McpServerConfig {
|
||||
name: "mock_server_1".to_string(),
|
||||
transport: McpTransport::Streamable {
|
||||
url: mock_server1.url(),
|
||||
token: None,
|
||||
},
|
||||
},
|
||||
McpServerConfig {
|
||||
name: "mock_server_2".to_string(),
|
||||
transport: McpTransport::Streamable {
|
||||
url: mock_server2.url(),
|
||||
token: None,
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let stats = mcp_server.get_tool_stats();
|
||||
assert!(stats.total_servers >= 1);
|
||||
assert!(stats.total_tools >= 2);
|
||||
}
|
||||
// Note: This will fail to connect to both servers in the current implementation
|
||||
// since they return the same tools. The manager will connect to the first one.
|
||||
let result = McpClientManager::new(config).await;
|
||||
|
||||
// Tool Session Management Tests
|
||||
if let Ok(mut manager) = result {
|
||||
let servers = manager.list_servers();
|
||||
assert!(!servers.is_empty(), "Should have at least one server");
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_individual_tool_session_creation() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
let tools = manager.list_tools();
|
||||
assert!(tools.len() >= 2, "Should have tools from servers");
|
||||
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
|
||||
let session_result = mcp_server.get_tool_session("brave_web_search").await;
|
||||
assert!(session_result.is_ok(), "Should create tool session");
|
||||
|
||||
let session = session_result.unwrap();
|
||||
assert!(session.is_ready(), "Session should be ready");
|
||||
assert!(session.connection_info().contains("HTTP"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_tool_session_manager() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
let available_tools = mcp_server.list_tools();
|
||||
assert!(
|
||||
!available_tools.is_empty(),
|
||||
"Should have tools from mock server"
|
||||
);
|
||||
|
||||
let session_manager_result = mcp_server
|
||||
.create_multi_tool_session(available_tools.clone())
|
||||
.await;
|
||||
assert!(
|
||||
session_manager_result.is_ok(),
|
||||
"Should create session manager"
|
||||
);
|
||||
|
||||
let session_manager = session_manager_result.unwrap();
|
||||
|
||||
for tool in &available_tools {
|
||||
assert!(session_manager.has_tool(tool));
|
||||
manager.shutdown().await;
|
||||
}
|
||||
|
||||
let stats = session_manager.session_stats();
|
||||
// After optimization: 1 session per server (not per tool)
|
||||
assert_eq!(stats.total_sessions, 1); // One session for the mock server
|
||||
assert_eq!(stats.ready_sessions, 1); // One ready session
|
||||
assert_eq!(stats.unique_servers, 1); // One unique server
|
||||
|
||||
// But we still have all tools available
|
||||
assert_eq!(session_manager.list_tools().len(), available_tools.len());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_execution_with_mock() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
let config = McpConfig {
|
||||
servers: vec![McpServerConfig {
|
||||
name: "mock_server".to_string(),
|
||||
transport: McpTransport::Streamable {
|
||||
url: mock_server.url(),
|
||||
token: None,
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let result = mcp_server
|
||||
let mut manager = McpClientManager::new(config).await.unwrap();
|
||||
|
||||
let result = manager
|
||||
.call_tool(
|
||||
"brave_web_search",
|
||||
json!({
|
||||
"query": "rust programming",
|
||||
"count": 1
|
||||
}),
|
||||
Some(
|
||||
json!({
|
||||
"query": "rust programming",
|
||||
"count": 1
|
||||
})
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.clone(),
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -175,48 +183,53 @@ async fn test_tool_execution_with_mock() {
|
||||
);
|
||||
|
||||
let response = result.unwrap();
|
||||
assert!(
|
||||
response.get("content").is_some(),
|
||||
"Response should have content"
|
||||
);
|
||||
assert_eq!(response.get("isError").unwrap(), false);
|
||||
assert!(!response.content.is_empty(), "Should have content");
|
||||
|
||||
let content = response.get("content").unwrap().as_array().unwrap();
|
||||
let text = content[0].get("text").unwrap().as_str().unwrap();
|
||||
assert!(text.contains("Mock search results for: rust programming"));
|
||||
// Check the content
|
||||
if let rmcp::model::RawContent::Text(text) = &response.content[0].raw {
|
||||
assert!(text
|
||||
.text
|
||||
.contains("Mock search results for: rust programming"));
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
}
|
||||
|
||||
manager.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_tool_execution() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut session_manager = MultiToolSessionManager::new();
|
||||
|
||||
session_manager
|
||||
.add_tools_from_server(
|
||||
mock_server.url(),
|
||||
vec![
|
||||
"brave_web_search".to_string(),
|
||||
"brave_local_search".to_string(),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let config = McpConfig {
|
||||
servers: vec![McpServerConfig {
|
||||
name: "mock_server".to_string(),
|
||||
transport: McpTransport::Streamable {
|
||||
url: mock_server.url(),
|
||||
token: None,
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let mut manager = McpClientManager::new(config).await.unwrap();
|
||||
|
||||
// Execute tools sequentially (true concurrent execution would require Arc<Mutex>)
|
||||
let tool_calls = vec![
|
||||
("brave_web_search".to_string(), json!({"query": "test1"})),
|
||||
("brave_local_search".to_string(), json!({"query": "test2"})),
|
||||
("brave_web_search", json!({"query": "test1"})),
|
||||
("brave_local_search", json!({"query": "test2"})),
|
||||
];
|
||||
|
||||
let results = session_manager.call_tools_concurrent(tool_calls).await;
|
||||
assert_eq!(results.len(), 2, "Should return results for both tools");
|
||||
for (tool_name, args) in tool_calls {
|
||||
let result = manager
|
||||
.call_tool(tool_name, Some(args.as_object().unwrap().clone()))
|
||||
.await;
|
||||
|
||||
for (i, result) in results.iter().enumerate() {
|
||||
assert!(result.is_ok(), "Tool {} should succeed with mock server", i);
|
||||
|
||||
let response = result.as_ref().unwrap();
|
||||
assert!(response.get("content").is_some());
|
||||
assert_eq!(response.get("isError").unwrap(), false);
|
||||
assert!(result.is_ok(), "Tool {} should succeed", tool_name);
|
||||
let response = result.unwrap();
|
||||
assert!(!response.content.is_empty(), "Should have content");
|
||||
}
|
||||
|
||||
manager.shutdown().await;
|
||||
}
|
||||
|
||||
// Error Handling Tests
|
||||
@@ -224,235 +237,221 @@ async fn test_concurrent_tool_execution() {
|
||||
#[tokio::test]
|
||||
async fn test_tool_execution_errors() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
let config = McpConfig {
|
||||
servers: vec![McpServerConfig {
|
||||
name: "mock_server".to_string(),
|
||||
transport: McpTransport::Streamable {
|
||||
url: mock_server.url(),
|
||||
token: None,
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let result = mcp_server.call_tool("unknown_tool", json!({})).await;
|
||||
let mut manager = McpClientManager::new(config).await.unwrap();
|
||||
|
||||
// Try to call unknown tool
|
||||
let result = manager
|
||||
.call_tool("unknown_tool", Some(serde_json::Map::new()))
|
||||
.await;
|
||||
assert!(result.is_err(), "Should fail for unknown tool");
|
||||
|
||||
let session = mcp_server
|
||||
.get_tool_session("brave_web_search")
|
||||
.await
|
||||
.unwrap();
|
||||
let session_result = session.call_tool("unknown_tool", json!({})).await;
|
||||
assert!(
|
||||
session_result.is_err(),
|
||||
"Session should fail for unknown tool"
|
||||
);
|
||||
match result.unwrap_err() {
|
||||
McpError::ToolNotFound(name) => {
|
||||
assert_eq!(name, "unknown_tool");
|
||||
}
|
||||
_ => panic!("Expected ToolNotFound error"),
|
||||
}
|
||||
|
||||
manager.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connection_without_server() {
|
||||
let mut server = MCPToolServer::new();
|
||||
let config = McpConfig {
|
||||
servers: vec![McpServerConfig {
|
||||
name: "nonexistent".to_string(),
|
||||
transport: McpTransport::Streamable {
|
||||
url: "http://localhost:9999/mcp".to_string(),
|
||||
token: None,
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let result = server
|
||||
.add_tool_server("http://localhost:9999/mcp".to_string())
|
||||
.await;
|
||||
let result = McpClientManager::new(config).await;
|
||||
assert!(result.is_err(), "Should fail when no server is running");
|
||||
|
||||
let error_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
error_msg.contains("Failed to connect") || error_msg.contains("Connection"),
|
||||
"Error should be connection-related: {}",
|
||||
error_msg
|
||||
);
|
||||
if let Err(e) = result {
|
||||
let error_msg = e.to_string();
|
||||
assert!(
|
||||
error_msg.contains("Failed to connect") || error_msg.contains("Connection"),
|
||||
"Error should be connection-related: {}",
|
||||
error_msg
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Schema Adaptation Tests
|
||||
// Schema Validation Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_schema_validation() {
|
||||
async fn test_tool_info_structure() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
let config = McpConfig {
|
||||
servers: vec![McpServerConfig {
|
||||
name: "mock_server".to_string(),
|
||||
transport: McpTransport::Streamable {
|
||||
url: mock_server.url(),
|
||||
token: None,
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let description = mcp_server.get_tool_description("brave_web_search");
|
||||
assert!(description.is_some(), "Should have tool description");
|
||||
let manager = McpClientManager::new(config).await.unwrap();
|
||||
|
||||
let desc_value = description.unwrap();
|
||||
assert!(desc_value.get("name").is_some());
|
||||
assert!(desc_value.get("description").is_some());
|
||||
let tools = manager.list_tools();
|
||||
let brave_search = tools
|
||||
.iter()
|
||||
.find(|t| t.name == "brave_web_search")
|
||||
.expect("Should have brave_web_search tool");
|
||||
|
||||
assert_eq!(brave_search.name, "brave_web_search");
|
||||
assert!(brave_search.description.contains("Mock web search"));
|
||||
assert_eq!(brave_search.server, "mock_server");
|
||||
assert!(brave_search.parameters.is_some());
|
||||
}
|
||||
|
||||
// SSE Parsing Tests
|
||||
// SSE Parsing Tests (simplified since we don't expose parse_sse_event)
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sse_event_parsing_success() {
|
||||
let valid_event = "data: {\"jsonrpc\": \"2.0\", \"id\": \"1\", \"result\": {\"test\": \"success\", \"content\": [{\"type\": \"text\", \"text\": \"Hello\"}]}}";
|
||||
async fn test_sse_connection() {
|
||||
let mock_server = create_mock_server().await;
|
||||
|
||||
let result = parse_sse_event(valid_event);
|
||||
assert!(result.is_ok(), "Valid SSE event should parse successfully");
|
||||
// Test SSE transport configuration
|
||||
let config = McpConfig {
|
||||
servers: vec![McpServerConfig {
|
||||
name: "sse_server".to_string(),
|
||||
transport: McpTransport::Sse {
|
||||
// Mock server doesn't support SSE, but we can test the config
|
||||
url: format!("http://127.0.0.1:{}/sse", mock_server.port),
|
||||
token: Some("test_token".to_string()),
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let parsed = result.unwrap();
|
||||
assert!(parsed.is_some(), "Should return parsed data");
|
||||
|
||||
let response = parsed.unwrap();
|
||||
assert_eq!(response["test"], "success");
|
||||
assert!(response.get("content").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sse_event_parsing_error() {
|
||||
let error_event = "data: {\"jsonrpc\": \"2.0\", \"id\": \"1\", \"error\": {\"code\": -1, \"message\": \"Rate limit exceeded\"}}";
|
||||
|
||||
let result = parse_sse_event(error_event);
|
||||
assert!(result.is_err(), "Error SSE event should return error");
|
||||
|
||||
let error_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
error_msg.contains("Rate limit exceeded"),
|
||||
"Should contain error message"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sse_event_parsing_empty() {
|
||||
let empty_event = "";
|
||||
let result = parse_sse_event(empty_event);
|
||||
assert!(result.is_ok(), "Empty event should parse successfully");
|
||||
assert!(result.unwrap().is_none(), "Empty event should return None");
|
||||
|
||||
let no_data_event = "event: ping\nid: 123";
|
||||
let result2 = parse_sse_event(no_data_event);
|
||||
assert!(result2.is_ok(), "Non-data event should parse successfully");
|
||||
assert!(
|
||||
result2.unwrap().is_none(),
|
||||
"Non-data event should return None"
|
||||
);
|
||||
// This will fail to connect but tests the configuration
|
||||
let result = McpClientManager::new(config).await;
|
||||
assert!(result.is_err(), "Mock server doesn't support SSE");
|
||||
}
|
||||
|
||||
// Connection Type Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connection_type_detection() {
|
||||
let mock_server = create_mock_server().await;
|
||||
async fn test_transport_types() {
|
||||
// Test different transport configurations
|
||||
|
||||
let session_result = ToolSession::new(mock_server.url()).await;
|
||||
assert!(session_result.is_ok(), "Should create HTTP session");
|
||||
// HTTP/Streamable transport
|
||||
let http_config = McpServerConfig {
|
||||
name: "http_server".to_string(),
|
||||
transport: McpTransport::Streamable {
|
||||
url: "http://localhost:8080/mcp".to_string(),
|
||||
token: Some("auth_token".to_string()),
|
||||
},
|
||||
};
|
||||
assert_eq!(http_config.name, "http_server");
|
||||
|
||||
let session = session_result.unwrap();
|
||||
assert!(session.connection_info().contains("HTTP"));
|
||||
assert!(session.is_ready(), "HTTP session should be ready");
|
||||
// SSE transport
|
||||
let sse_config = McpServerConfig {
|
||||
name: "sse_server".to_string(),
|
||||
transport: McpTransport::Sse {
|
||||
url: "http://localhost:8081/sse".to_string(),
|
||||
token: None,
|
||||
},
|
||||
};
|
||||
assert_eq!(sse_config.name, "sse_server");
|
||||
|
||||
// Stdio sessions are no longer supported - test invalid URL handling
|
||||
let invalid_session = ToolSession::new("invalid-url".to_string()).await;
|
||||
assert!(invalid_session.is_err(), "Should reject non-HTTP URLs");
|
||||
// STDIO transport
|
||||
let stdio_config = McpServerConfig {
|
||||
name: "stdio_server".to_string(),
|
||||
transport: McpTransport::Stdio {
|
||||
command: "mcp-server".to_string(),
|
||||
args: vec!["--port".to_string(), "8082".to_string()],
|
||||
envs: HashMap::new(),
|
||||
},
|
||||
};
|
||||
assert_eq!(stdio_config.name, "stdio_server");
|
||||
}
|
||||
|
||||
// Integration Pattern Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_responses_api_integration_patterns() {
|
||||
async fn test_complete_workflow() {
|
||||
let mock_server = create_mock_server().await;
|
||||
|
||||
// Server initialization
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
// 1. Initialize configuration
|
||||
let config = McpConfig {
|
||||
servers: vec![McpServerConfig {
|
||||
name: "integration_test".to_string(),
|
||||
transport: McpTransport::Streamable {
|
||||
url: mock_server.url(),
|
||||
token: None,
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
// Tool server connection (like responses API startup)
|
||||
match mcp_server.add_tool_server(mock_server.url()).await {
|
||||
Ok(_) => {
|
||||
let stats = mcp_server.get_tool_stats();
|
||||
assert_eq!(stats.total_tools, 2);
|
||||
assert_eq!(stats.total_servers, 1);
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("Should connect to mock server: {}", e);
|
||||
}
|
||||
}
|
||||
// 2. Connect to server
|
||||
let mut manager = McpClientManager::new(config)
|
||||
.await
|
||||
.expect("Should connect to mock server");
|
||||
|
||||
// Tool availability checking
|
||||
let test_tools = vec!["brave_web_search", "brave_local_search", "calculator"];
|
||||
for tool in &test_tools {
|
||||
let _available = mcp_server.has_tool(tool);
|
||||
}
|
||||
// 3. Verify server connection
|
||||
let servers = manager.list_servers();
|
||||
assert_eq!(servers.len(), 1);
|
||||
assert_eq!(servers[0], "integration_test");
|
||||
|
||||
// Tool session creation
|
||||
if mcp_server.has_tool("brave_web_search") {
|
||||
let session_result = mcp_server.get_tool_session("brave_web_search").await;
|
||||
assert!(session_result.is_ok(), "Should create tool session");
|
||||
}
|
||||
// 4. Check available tools
|
||||
let tools = manager.list_tools();
|
||||
assert_eq!(tools.len(), 2);
|
||||
|
||||
// Multi-tool session creation
|
||||
let available_tools = mcp_server.list_tools();
|
||||
if !available_tools.is_empty() {
|
||||
let session_manager_result = mcp_server.create_multi_tool_session(available_tools).await;
|
||||
assert!(
|
||||
session_manager_result.is_ok(),
|
||||
"Should create multi-tool session"
|
||||
);
|
||||
}
|
||||
// 5. Verify specific tools exist
|
||||
assert!(manager.has_tool("brave_web_search"));
|
||||
assert!(manager.has_tool("brave_local_search"));
|
||||
assert!(!manager.has_tool("nonexistent_tool"));
|
||||
|
||||
// Tool execution
|
||||
let result = mcp_server
|
||||
// 6. Execute a tool
|
||||
let result = manager
|
||||
.call_tool(
|
||||
"brave_web_search",
|
||||
json!({
|
||||
"query": "SGLang router MCP integration",
|
||||
"count": 1
|
||||
}),
|
||||
Some(
|
||||
json!({
|
||||
"query": "SGLang router MCP integration",
|
||||
"count": 1
|
||||
})
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.clone(),
|
||||
),
|
||||
)
|
||||
.await;
|
||||
if result.is_err() {
|
||||
// This might fail if called after another test that uses the same tool name
|
||||
// Due to the shared mock server. That's OK, the main test covers this.
|
||||
return;
|
||||
}
|
||||
assert!(result.is_ok(), "Should execute tool successfully");
|
||||
}
|
||||
|
||||
// Complete Integration Test
|
||||
assert!(result.is_ok(), "Tool execution should succeed");
|
||||
let response = result.unwrap();
|
||||
assert!(!response.content.is_empty(), "Should return content");
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_responses_api_integration() {
|
||||
let mock_server = create_mock_server().await;
|
||||
|
||||
// Run through all functionality required for responses API integration
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
|
||||
// Test all core functionality
|
||||
assert!(mcp_server.has_tool("brave_web_search"));
|
||||
|
||||
let session = mcp_server
|
||||
.get_tool_session("brave_web_search")
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(session.is_ready());
|
||||
|
||||
let session_manager = mcp_server
|
||||
.create_multi_tool_session(mcp_server.list_tools())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(session_manager.session_stats().total_sessions > 0);
|
||||
|
||||
let result = mcp_server
|
||||
.call_tool(
|
||||
"brave_web_search",
|
||||
json!({
|
||||
"query": "test",
|
||||
"count": 1
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.get("content").is_some());
|
||||
// 7. Clean shutdown
|
||||
manager.shutdown().await;
|
||||
|
||||
// Verify all required capabilities for responses API integration
|
||||
let capabilities = [
|
||||
"MCP server initialization",
|
||||
"Tool server connection and discovery",
|
||||
"Tool availability checking",
|
||||
"Individual tool session management",
|
||||
"Multi-tool session manager (Python tool_session_ctxs pattern)",
|
||||
"Concurrent tool execution",
|
||||
"Direct tool execution",
|
||||
"Tool execution",
|
||||
"Error handling and robustness",
|
||||
"Protocol compliance (SSE parsing)",
|
||||
"Schema adaptation (Python parity)",
|
||||
"Multi-server support",
|
||||
"Schema adaptation",
|
||||
"Mock server integration (no external dependencies)",
|
||||
];
|
||||
|
||||
assert_eq!(capabilities.len(), 11);
|
||||
assert_eq!(capabilities.len(), 8);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user