[router] Add MCP Tool Handler (#9615)
This commit is contained in:
@@ -5,6 +5,7 @@ use std::collections::HashMap;
|
||||
pub mod core;
|
||||
#[cfg(feature = "grpc-client")]
|
||||
pub mod grpc;
|
||||
pub mod mcp;
|
||||
pub mod metrics;
|
||||
pub mod middleware;
|
||||
pub mod policies;
|
||||
|
||||
9
sgl-router/src/mcp/mod.rs
Normal file
9
sgl-router/src/mcp/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
// mod.rs - MCP module exports
|
||||
pub mod tool_server;
|
||||
pub mod types;
|
||||
|
||||
pub use tool_server::{parse_sse_event, MCPToolServer, ToolStats};
|
||||
pub use types::{
|
||||
HttpConnection, MCPError, MCPResult, MultiToolSessionManager, SessionStats, ToolCall,
|
||||
ToolResult, ToolSession,
|
||||
};
|
||||
534
sgl-router/src/mcp/tool_server.rs
Normal file
534
sgl-router/src/mcp/tool_server.rs
Normal file
@@ -0,0 +1,534 @@
|
||||
// tool_server.rs - Main MCP implementation (matching Python's tool_server.py)
|
||||
use crate::mcp::types::*;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Main MCP Tool Server
|
||||
pub struct MCPToolServer {
|
||||
/// Tool descriptions by server
|
||||
tool_descriptions: HashMap<String, Value>,
|
||||
/// Server URLs
|
||||
urls: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Default for MCPToolServer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl MCPToolServer {
|
||||
/// Create new MCPToolServer
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tool_descriptions: HashMap::new(),
|
||||
urls: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears all existing tool servers and adds new ones from the provided URL(s).
|
||||
/// URLs can be a single string or multiple comma-separated strings.
|
||||
pub async fn add_tool_server(&mut self, server_url: String) -> MCPResult<()> {
|
||||
let tool_urls: Vec<&str> = server_url.split(",").collect();
|
||||
let mut successful_connections = 0;
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Clear existing
|
||||
self.tool_descriptions = HashMap::new();
|
||||
self.urls = HashMap::new();
|
||||
|
||||
for url_str in tool_urls {
|
||||
let url_str = url_str.trim();
|
||||
|
||||
// Format URL for MCP-compliant connection
|
||||
let formatted_url = if url_str.starts_with("http://") || url_str.starts_with("https://")
|
||||
{
|
||||
url_str.to_string()
|
||||
} else {
|
||||
// Default to MCP endpoint if no protocol specified
|
||||
format!("http://{}", url_str)
|
||||
};
|
||||
|
||||
// Server connection with retry and error recovery
|
||||
match self.connect_to_server(&formatted_url).await {
|
||||
Ok((_init_response, tools_response)) => {
|
||||
// Process tools with validation
|
||||
let tools_obj = post_process_tools_description(tools_response);
|
||||
|
||||
// Tool storage with conflict detection
|
||||
for tool in &tools_obj.tools {
|
||||
let tool_name = &tool.name;
|
||||
|
||||
// Check for duplicate tools
|
||||
if self.tool_descriptions.contains_key(tool_name) {
|
||||
tracing::warn!(
|
||||
"Tool {} already exists. Ignoring duplicate tool from server {}",
|
||||
tool_name,
|
||||
formatted_url
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Store individual tool descriptions
|
||||
let tool_json = json!(tool);
|
||||
self.tool_descriptions
|
||||
.insert(tool_name.clone(), tool_json.clone());
|
||||
self.urls.insert(tool_name.clone(), formatted_url.clone());
|
||||
}
|
||||
|
||||
successful_connections += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
errors.push(format!("Failed to connect to {}: {}", formatted_url, e));
|
||||
tracing::warn!("Failed to connect to MCP server {}: {}", formatted_url, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Error handling - succeed if at least one server connects
|
||||
if successful_connections == 0 {
|
||||
let combined_error = errors.join("; ");
|
||||
return Err(MCPError::ConnectionError(format!(
|
||||
"Failed to connect to any MCP servers: {}",
|
||||
combined_error
|
||||
)));
|
||||
}
|
||||
|
||||
if !errors.is_empty() {
|
||||
tracing::warn!("Some MCP servers failed to connect: {}", errors.join("; "));
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"Successfully connected to {} MCP server(s), discovered {} tool(s)",
|
||||
successful_connections,
|
||||
self.tool_descriptions.len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Server connection with retries (internal helper)
|
||||
async fn connect_to_server(
|
||||
&self,
|
||||
url: &str,
|
||||
) -> MCPResult<(InitializeResponse, ListToolsResponse)> {
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
const RETRY_DELAY_MS: u64 = 1000;
|
||||
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 1..=MAX_RETRIES {
|
||||
match list_server_and_tools(url).await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(e) => {
|
||||
last_error = Some(e);
|
||||
if attempt < MAX_RETRIES {
|
||||
tracing::debug!(
|
||||
"MCP server connection attempt {}/{} failed for {}: {}. Retrying...",
|
||||
attempt,
|
||||
MAX_RETRIES,
|
||||
url,
|
||||
last_error.as_ref().unwrap()
|
||||
);
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(
|
||||
RETRY_DELAY_MS * attempt as u64,
|
||||
))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap())
|
||||
}
|
||||
|
||||
/// Check if tool exists (matching Python's has_tool)
|
||||
pub fn has_tool(&self, tool_name: &str) -> bool {
|
||||
self.tool_descriptions.contains_key(tool_name)
|
||||
}
|
||||
|
||||
/// Get tool description (matching Python's get_tool_description)
|
||||
pub fn get_tool_description(&self, tool_name: &str) -> Option<&Value> {
|
||||
self.tool_descriptions.get(tool_name)
|
||||
}
|
||||
|
||||
/// Get tool session (matching Python's get_tool_session)
|
||||
pub async fn get_tool_session(&self, tool_name: &str) -> MCPResult<ToolSession> {
|
||||
let url = self
|
||||
.urls
|
||||
.get(tool_name)
|
||||
.ok_or_else(|| MCPError::ToolNotFound(tool_name.to_string()))?;
|
||||
|
||||
// Create session
|
||||
ToolSession::new(url.clone()).await
|
||||
}
|
||||
|
||||
/// Create multi-tool session manager
|
||||
pub async fn create_multi_tool_session(
|
||||
&self,
|
||||
tool_names: Vec<String>,
|
||||
) -> MCPResult<MultiToolSessionManager> {
|
||||
let mut session_manager = MultiToolSessionManager::new();
|
||||
|
||||
// Group tools by server URL for efficient session creation
|
||||
let mut server_tools: std::collections::HashMap<String, Vec<String>> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
for tool_name in tool_names {
|
||||
if let Some(url) = self.urls.get(&tool_name) {
|
||||
server_tools.entry(url.clone()).or_default().push(tool_name);
|
||||
} else {
|
||||
return Err(MCPError::ToolNotFound(format!(
|
||||
"Tool not found: {}",
|
||||
tool_name
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Create sessions for each server
|
||||
for (server_url, tools) in server_tools {
|
||||
session_manager
|
||||
.add_tools_from_server(server_url, tools)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(session_manager)
|
||||
}
|
||||
|
||||
/// List all available tools
|
||||
pub fn list_tools(&self) -> Vec<String> {
|
||||
self.tool_descriptions.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get tool statistics
|
||||
pub fn get_tool_stats(&self) -> ToolStats {
|
||||
ToolStats {
|
||||
total_tools: self.tool_descriptions.len(),
|
||||
total_servers: self
|
||||
.urls
|
||||
.values()
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// List all connected servers
|
||||
pub fn list_servers(&self) -> Vec<String> {
|
||||
self.urls
|
||||
.values()
|
||||
.cloned()
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Check if a specific server is connected
|
||||
pub fn has_server(&self, server_url: &str) -> bool {
|
||||
self.urls.values().any(|url| url == server_url)
|
||||
}
|
||||
|
||||
/// Execute a tool directly (convenience method for simple usage)
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> MCPResult<serde_json::Value> {
|
||||
let session = self.get_tool_session(tool_name).await?;
|
||||
session.call_tool(tool_name, arguments).await
|
||||
}
|
||||
|
||||
/// Create a tool session from server URL (convenience method)
|
||||
pub async fn create_session_from_url(&self, server_url: &str) -> MCPResult<ToolSession> {
|
||||
ToolSession::new(server_url.to_string()).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool statistics for monitoring
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolStats {
|
||||
pub total_tools: usize,
|
||||
pub total_servers: usize,
|
||||
}
|
||||
|
||||
/// MCP-compliant server connection using JSON-RPC over SSE
|
||||
async fn list_server_and_tools(
|
||||
server_url: &str,
|
||||
) -> MCPResult<(InitializeResponse, ListToolsResponse)> {
|
||||
// MCP specification:
|
||||
// 1. Connect to MCP endpoint with GET (SSE) or POST (JSON-RPC)
|
||||
// 2. Send initialize request
|
||||
// 3. Send tools/list request
|
||||
// 4. Parse JSON-RPC responses
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Step 1: Send initialize request
|
||||
let init_request = MCPRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: "1".to_string(),
|
||||
method: "initialize".to_string(),
|
||||
params: Some(json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {}
|
||||
})),
|
||||
};
|
||||
|
||||
let init_response = send_mcp_request(&client, server_url, init_request).await?;
|
||||
let init_result: InitializeResponse = serde_json::from_value(init_response).map_err(|e| {
|
||||
MCPError::SerializationError(format!("Failed to parse initialize response: {}", e))
|
||||
})?;
|
||||
|
||||
// Step 2: Send tools/list request
|
||||
let tools_request = MCPRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: "2".to_string(),
|
||||
method: "tools/list".to_string(),
|
||||
params: Some(json!({})),
|
||||
};
|
||||
|
||||
let tools_response = send_mcp_request(&client, server_url, tools_request).await?;
|
||||
let tools_result: ListToolsResponse = serde_json::from_value(tools_response).map_err(|e| {
|
||||
MCPError::SerializationError(format!("Failed to parse tools/list response: {}", e))
|
||||
})?;
|
||||
|
||||
Ok((init_result, tools_result))
|
||||
}
|
||||
|
||||
/// Send MCP JSON-RPC request (supports both HTTP POST and SSE)
|
||||
async fn send_mcp_request(
|
||||
client: &reqwest::Client,
|
||||
url: &str,
|
||||
request: MCPRequest,
|
||||
) -> MCPResult<Value> {
|
||||
// Use HTTP POST for JSON-RPC requests
|
||||
let response = client
|
||||
.post(url)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| MCPError::ConnectionError(format!("MCP request failed: {}", e)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(MCPError::ProtocolError(format!(
|
||||
"HTTP {}",
|
||||
response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let mcp_response: MCPResponse = response.json().await.map_err(|e| {
|
||||
MCPError::SerializationError(format!("Failed to parse MCP response: {}", e))
|
||||
})?;
|
||||
|
||||
if let Some(error) = mcp_response.error {
|
||||
return Err(MCPError::ProtocolError(format!(
|
||||
"MCP error: {}",
|
||||
error.message
|
||||
)));
|
||||
}
|
||||
|
||||
mcp_response
|
||||
.result
|
||||
.ok_or_else(|| MCPError::ProtocolError("No result in MCP response".to_string()))
|
||||
}
|
||||
|
||||
// Removed old send_http_request - now using send_mcp_request with proper MCP protocol
|
||||
|
||||
/// Parse SSE event format (MCP-compliant JSON-RPC only)
|
||||
pub fn parse_sse_event(event: &str) -> MCPResult<Option<Value>> {
|
||||
let mut data_lines = Vec::new();
|
||||
|
||||
for line in event.lines() {
|
||||
if let Some(stripped) = line.strip_prefix("data: ") {
|
||||
data_lines.push(stripped);
|
||||
}
|
||||
}
|
||||
|
||||
if data_lines.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let json_data = data_lines.join("\n");
|
||||
if json_data.trim().is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Parse as MCP JSON-RPC response only (no custom events)
|
||||
let mcp_response: MCPResponse = serde_json::from_str(&json_data).map_err(|e| {
|
||||
MCPError::SerializationError(format!(
|
||||
"Failed to parse JSON-RPC response: {} - Data: {}",
|
||||
e, json_data
|
||||
))
|
||||
})?;
|
||||
|
||||
if let Some(error) = mcp_response.error {
|
||||
return Err(MCPError::ProtocolError(error.message));
|
||||
}
|
||||
|
||||
Ok(mcp_response.result)
|
||||
}
|
||||
|
||||
/// Schema adaptation matching Python's trim_schema()
|
||||
fn trim_schema(schema: &mut Value) {
|
||||
if let Some(obj) = schema.as_object_mut() {
|
||||
// Remove title and null defaults
|
||||
obj.remove("title");
|
||||
if obj.get("default") == Some(&Value::Null) {
|
||||
obj.remove("default");
|
||||
}
|
||||
|
||||
// Convert anyOf to type arrays
|
||||
if let Some(any_of) = obj.remove("anyOf") {
|
||||
if let Some(array) = any_of.as_array() {
|
||||
let types: Vec<String> = array
|
||||
.iter()
|
||||
.filter_map(|item| {
|
||||
item.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|t| *t != "null")
|
||||
.map(|t| t.to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Handle single type vs array of types
|
||||
match types.len() {
|
||||
0 => {} // No valid types found
|
||||
1 => {
|
||||
obj.insert("type".to_string(), json!(types[0]));
|
||||
}
|
||||
_ => {
|
||||
obj.insert("type".to_string(), json!(types));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle oneOf similar to anyOf
|
||||
if let Some(one_of) = obj.remove("oneOf") {
|
||||
if let Some(array) = one_of.as_array() {
|
||||
let types: Vec<String> = array
|
||||
.iter()
|
||||
.filter_map(|item| {
|
||||
item.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|t| *t != "null")
|
||||
.map(|t| t.to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !types.is_empty() {
|
||||
obj.insert("type".to_string(), json!(types));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recursive processing for properties
|
||||
if let Some(properties) = obj.get_mut("properties") {
|
||||
if let Some(props_obj) = properties.as_object_mut() {
|
||||
for (_, value) in props_obj.iter_mut() {
|
||||
trim_schema(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle nested schemas in items (for arrays)
|
||||
if let Some(items) = obj.get_mut("items") {
|
||||
trim_schema(items);
|
||||
}
|
||||
|
||||
// Handle nested schemas in additionalProperties
|
||||
if let Some(additional_props) = obj.get_mut("additionalProperties") {
|
||||
if additional_props.is_object() {
|
||||
trim_schema(additional_props);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle patternProperties (for dynamic property names)
|
||||
if let Some(pattern_props) = obj.get_mut("patternProperties") {
|
||||
if let Some(pattern_obj) = pattern_props.as_object_mut() {
|
||||
for (_, value) in pattern_obj.iter_mut() {
|
||||
trim_schema(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle allOf in nested contexts
|
||||
if let Some(all_of) = obj.get_mut("allOf") {
|
||||
if let Some(array) = all_of.as_array_mut() {
|
||||
for item in array.iter_mut() {
|
||||
trim_schema(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool processing with filtering
|
||||
fn post_process_tools_description(mut tools_response: ListToolsResponse) -> ListToolsResponse {
|
||||
// Adapt schemas for Harmony
|
||||
for tool in &mut tools_response.tools {
|
||||
trim_schema(&mut tool.input_schema);
|
||||
}
|
||||
|
||||
// Tool filtering based on annotations
|
||||
let initial_count = tools_response.tools.len();
|
||||
|
||||
tools_response.tools.retain(|tool| {
|
||||
// Check include_in_prompt annotation (Python behavior)
|
||||
let include_in_prompt = tool
|
||||
.annotations
|
||||
.as_ref()
|
||||
.and_then(|a| a.get("include_in_prompt"))
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(true);
|
||||
|
||||
if !include_in_prompt {
|
||||
tracing::debug!(
|
||||
"Filtering out tool '{}' due to include_in_prompt=false",
|
||||
tool.name
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if tool is explicitly disabled
|
||||
let disabled = tool
|
||||
.annotations
|
||||
.as_ref()
|
||||
.and_then(|a| a.get("disabled"))
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
if disabled {
|
||||
tracing::debug!("Filtering out disabled tool '{}'", tool.name);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate tool has required fields
|
||||
if tool.name.trim().is_empty() {
|
||||
tracing::warn!("Filtering out tool with empty name");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check for valid input schema
|
||||
if tool.input_schema.is_null() {
|
||||
tracing::warn!("Tool '{}' has null input schema, but keeping it", tool.name);
|
||||
}
|
||||
|
||||
true
|
||||
});
|
||||
|
||||
let filtered_count = tools_response.tools.len();
|
||||
if filtered_count != initial_count {
|
||||
tracing::info!(
|
||||
"Filtered tools: {} -> {} ({} removed)",
|
||||
initial_count,
|
||||
filtered_count,
|
||||
initial_count - filtered_count
|
||||
);
|
||||
}
|
||||
|
||||
tools_response
|
||||
}
|
||||
|
||||
// Tests moved to tests/mcp_comprehensive_test.rs for better organization
|
||||
345
sgl-router/src/mcp/types.rs
Normal file
345
sgl-router/src/mcp/types.rs
Normal file
@@ -0,0 +1,345 @@
|
||||
// types.rs - All MCP data structures
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use uuid;
|
||||
|
||||
// ===== Errors =====
|
||||
#[derive(Error, Debug)]
|
||||
pub enum MCPError {
|
||||
#[error("Connection failed: {0}")]
|
||||
ConnectionError(String),
|
||||
#[error("Invalid URL: {0}")]
|
||||
InvalidURL(String),
|
||||
#[error("Protocol error: {0}")]
|
||||
ProtocolError(String),
|
||||
#[error("Tool execution failed: {0}")]
|
||||
ToolExecutionError(String),
|
||||
#[error("Tool not found: {0}")]
|
||||
ToolNotFound(String),
|
||||
#[error("Serialization error: {0}")]
|
||||
SerializationError(String),
|
||||
#[error("Configuration error: {0}")]
|
||||
ConfigurationError(String),
|
||||
}
|
||||
|
||||
pub type MCPResult<T> = Result<T, MCPError>;
|
||||
|
||||
// Add From implementations for common error types
|
||||
impl From<serde_json::Error> for MCPError {
|
||||
fn from(err: serde_json::Error) -> Self {
|
||||
MCPError::SerializationError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for MCPError {
|
||||
fn from(err: reqwest::Error) -> Self {
|
||||
MCPError::ConnectionError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// ===== MCP Protocol Types =====
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MCPRequest {
|
||||
pub jsonrpc: String,
|
||||
pub id: String,
|
||||
pub method: String,
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MCPResponse {
|
||||
pub jsonrpc: String,
|
||||
pub id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<MCPErrorResponse>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MCPErrorResponse {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
// ===== MCP Server Response Types =====
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct InitializeResponse {
|
||||
#[serde(rename = "serverInfo")]
|
||||
pub server_info: ServerInfo,
|
||||
pub instructions: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ServerInfo {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ListToolsResponse {
|
||||
pub tools: Vec<ToolInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolInfo {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "inputSchema")]
|
||||
pub input_schema: serde_json::Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub annotations: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
// ===== Types =====
|
||||
pub type ToolCall = serde_json::Value; // Python uses dict
|
||||
pub type ToolResult = serde_json::Value; // Python uses dict
|
||||
|
||||
// ===== Connection Types =====
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HttpConnection {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
// ===== Tool Session =====
|
||||
pub struct ToolSession {
|
||||
pub connection: HttpConnection,
|
||||
pub client: reqwest::Client,
|
||||
pub session_initialized: bool,
|
||||
}
|
||||
|
||||
impl ToolSession {
|
||||
pub async fn new(connection_str: String) -> MCPResult<Self> {
|
||||
if !connection_str.starts_with("http://") && !connection_str.starts_with("https://") {
|
||||
return Err(MCPError::InvalidURL(format!(
|
||||
"Only HTTP/HTTPS URLs are supported: {}",
|
||||
connection_str
|
||||
)));
|
||||
}
|
||||
|
||||
let mut session = Self {
|
||||
connection: HttpConnection {
|
||||
url: connection_str,
|
||||
},
|
||||
client: reqwest::Client::new(),
|
||||
session_initialized: false,
|
||||
};
|
||||
|
||||
// Initialize the session
|
||||
session.initialize().await?;
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub async fn new_http(url: String) -> MCPResult<Self> {
|
||||
Self::new(url).await
|
||||
}
|
||||
|
||||
/// Initialize the session
|
||||
pub async fn initialize(&mut self) -> MCPResult<()> {
|
||||
if self.session_initialized {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let init_request = MCPRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: "init".to_string(),
|
||||
method: "initialize".to_string(),
|
||||
params: Some(serde_json::json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {}
|
||||
})),
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.connection.url)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&init_request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| MCPError::ConnectionError(format!("Initialize failed: {}", e)))?;
|
||||
|
||||
let mcp_response: MCPResponse = response.json().await.map_err(|e| {
|
||||
MCPError::SerializationError(format!("Failed to parse initialize response: {}", e))
|
||||
})?;
|
||||
|
||||
if let Some(error) = mcp_response.error {
|
||||
return Err(MCPError::ProtocolError(format!(
|
||||
"Initialize error: {}",
|
||||
error.message
|
||||
)));
|
||||
}
|
||||
|
||||
self.session_initialized = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Call a tool using MCP tools/call
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> MCPResult<serde_json::Value> {
|
||||
if !self.session_initialized {
|
||||
return Err(MCPError::ProtocolError(
|
||||
"Session not initialized. Call initialize() first.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
let request = MCPRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: format!("call_{}", uuid::Uuid::new_v4()),
|
||||
method: "tools/call".to_string(),
|
||||
params: Some(json!({
|
||||
"name": name,
|
||||
"arguments": arguments
|
||||
})),
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.connection.url)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| MCPError::ConnectionError(format!("Tool call failed: {}", e)))?;
|
||||
|
||||
let mcp_response: MCPResponse = response.json().await.map_err(|e| {
|
||||
MCPError::SerializationError(format!("Failed to parse tool response: {}", e))
|
||||
})?;
|
||||
|
||||
if let Some(error) = mcp_response.error {
|
||||
return Err(MCPError::ToolExecutionError(format!(
|
||||
"Tool '{}' failed: {}",
|
||||
name, error.message
|
||||
)));
|
||||
}
|
||||
|
||||
mcp_response
|
||||
.result
|
||||
.ok_or_else(|| MCPError::ProtocolError("No result in tool response".to_string()))
|
||||
}
|
||||
|
||||
/// Check if session is ready for tool calls
|
||||
pub fn is_ready(&self) -> bool {
|
||||
self.session_initialized
|
||||
}
|
||||
|
||||
/// Get connection info
|
||||
pub fn connection_info(&self) -> String {
|
||||
format!("HTTP: {}", self.connection.url)
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Multi-Tool Session Manager =====
|
||||
pub struct MultiToolSessionManager {
|
||||
sessions: HashMap<String, ToolSession>, // server_url -> session
|
||||
tool_to_server: HashMap<String, String>, // tool_name -> server_url mapping
|
||||
}
|
||||
|
||||
impl Default for MultiToolSessionManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl MultiToolSessionManager {
|
||||
/// Create new multi-tool session manager
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
tool_to_server: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add tools from an MCP server (optimized to share sessions per server)
|
||||
pub async fn add_tools_from_server(
|
||||
&mut self,
|
||||
server_url: String,
|
||||
tool_names: Vec<String>,
|
||||
) -> MCPResult<()> {
|
||||
// Create one session per server URL (if not already exists)
|
||||
if !self.sessions.contains_key(&server_url) {
|
||||
let session = ToolSession::new(server_url.clone()).await?;
|
||||
self.sessions.insert(server_url.clone(), session);
|
||||
}
|
||||
|
||||
// Map all tools to this server URL
|
||||
for tool_name in tool_names {
|
||||
self.tool_to_server.insert(tool_name, server_url.clone());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get session for a specific tool
|
||||
pub fn get_session(&self, tool_name: &str) -> Option<&ToolSession> {
|
||||
let server_url = self.tool_to_server.get(tool_name)?;
|
||||
self.sessions.get(server_url)
|
||||
}
|
||||
|
||||
/// Execute tool with automatic session management
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> MCPResult<serde_json::Value> {
|
||||
let server_url = self
|
||||
.tool_to_server
|
||||
.get(tool_name)
|
||||
.ok_or_else(|| MCPError::ToolNotFound(format!("No mapping for tool: {}", tool_name)))?;
|
||||
|
||||
let session = self.sessions.get(server_url).ok_or_else(|| {
|
||||
MCPError::ToolNotFound(format!("No session for server: {}", server_url))
|
||||
})?;
|
||||
|
||||
session.call_tool(tool_name, arguments).await
|
||||
}
|
||||
|
||||
/// Execute multiple tools concurrently
|
||||
pub async fn call_tools_concurrent(
|
||||
&self,
|
||||
tool_calls: Vec<(String, serde_json::Value)>,
|
||||
) -> Vec<MCPResult<serde_json::Value>> {
|
||||
let futures: Vec<_> = tool_calls
|
||||
.into_iter()
|
||||
.map(|(tool_name, args)| async move { self.call_tool(&tool_name, args).await })
|
||||
.collect();
|
||||
|
||||
futures::future::join_all(futures).await
|
||||
}
|
||||
|
||||
/// Get all available tool names
|
||||
pub fn list_tools(&self) -> Vec<String> {
|
||||
self.tool_to_server.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Check if tool is available
|
||||
pub fn has_tool(&self, tool_name: &str) -> bool {
|
||||
self.tool_to_server.contains_key(tool_name)
|
||||
}
|
||||
|
||||
/// Get session statistics
|
||||
pub fn session_stats(&self) -> SessionStats {
|
||||
let total_sessions = self.sessions.len();
|
||||
let ready_sessions = self.sessions.values().filter(|s| s.is_ready()).count();
|
||||
let unique_servers = self.sessions.len(); // Now sessions = servers
|
||||
|
||||
SessionStats {
|
||||
total_sessions,
|
||||
ready_sessions,
|
||||
unique_servers,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionStats {
|
||||
pub total_sessions: usize,
|
||||
pub ready_sessions: usize,
|
||||
pub unique_servers: usize,
|
||||
}
|
||||
Reference in New Issue
Block a user