diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index fd4862054..4ecfae55d 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -55,6 +55,15 @@ tiktoken-rs = { version = "0.7.0" } minijinja = { version = "2.0" } rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } hf-hub = { version = "0.4.3", features = ["tokio"] } +rmcp = { version = "0.6.3", features = ["client", "server", + "transport-child-process", + "transport-sse-client-reqwest", + "transport-streamable-http-client-reqwest", + "transport-streamable-http-server", + "transport-streamable-http-server-session", + "reqwest", + "auth"] } +serde_yaml = "0.9" # gRPC and Protobuf dependencies tonic = { version = "0.12", features = ["tls", "gzip", "transport"] } diff --git a/sgl-router/src/mcp/client_manager.rs b/sgl-router/src/mcp/client_manager.rs new file mode 100644 index 000000000..a2a6d7a7e --- /dev/null +++ b/sgl-router/src/mcp/client_manager.rs @@ -0,0 +1,535 @@ +use backoff::ExponentialBackoffBuilder; +use dashmap::DashMap; +use rmcp::{ + model::{ + CallToolRequestParam, GetPromptRequestParam, GetPromptResult, Prompt, + ReadResourceRequestParam, ReadResourceResult, Resource, Tool as McpTool, + }, + service::RunningService, + transport::{ + sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig, + ConfigureCommandExt, SseClientTransport, StreamableHttpClientTransport, TokioChildProcess, + }, + RoleClient, ServiceExt, +}; +use serde::{Deserialize, Serialize}; +use std::{borrow::Cow, collections::HashMap, time::Duration}; + +use crate::mcp::{ + config::{McpConfig, McpServerConfig, McpTransport}, + error::{McpError, McpResult}, +}; + +/// Information about an available tool +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolInfo { + pub name: String, + pub description: String, + pub server: String, + pub parameters: Option, +} + +/// Information about an available prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PromptInfo { + pub name: String, + pub description: Option, + pub server: String, + pub arguments: Option>, +} + +/// Information about an available resource +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResourceInfo { + pub uri: String, + pub name: String, + pub description: Option, + pub mime_type: Option, + pub server: String, +} + +/// Manages MCP client connections and tool execution +pub struct McpClientManager { + /// Map of server_name -> MCP client + clients: HashMap>, + /// Map of tool_name -> (server_name, tool_definition) + tools: DashMap, + /// Map of prompt_name -> (server_name, prompt_definition) + prompts: DashMap, + /// Map of resource_uri -> (server_name, resource_definition) + resources: DashMap, +} + +impl McpClientManager { + /// Create a new manager and connect to all configured servers + pub async fn new(config: McpConfig) -> McpResult { + let mut mgr = Self { + clients: HashMap::new(), + tools: DashMap::new(), + prompts: DashMap::new(), + resources: DashMap::new(), + }; + + for server_config in config.servers { + match Self::connect_server(&server_config).await { + Ok(client) => { + mgr.load_server_inventory(&server_config.name, &client) + .await; + mgr.clients.insert(server_config.name.clone(), client); + } + Err(e) => { + tracing::error!( + "Failed to connect to server '{}': {}", + server_config.name, + e + ); + } + } + } + + if mgr.clients.is_empty() { + return Err(McpError::ConnectionFailed( + "Failed to connect to any MCP servers".to_string(), + )); + } + Ok(mgr) + } + + /// Discover and cache tools/prompts/resources for a connected server + async fn load_server_inventory( + &self, + server_name: &str, + client: &RunningService, + ) { + // Tools + match client.peer().list_all_tools().await { + Ok(ts) => { + tracing::info!("Discovered {} tools from '{}'", ts.len(), server_name); + for t in ts { + if self.tools.contains_key(t.name.as_ref()) { + tracing::warn!( + "Tool '{}' from server '{}' is overwriting an existing tool.", + &t.name, + server_name + ); + } + self.tools + .insert(t.name.to_string(), (server_name.to_string(), t)); + } + } + Err(e) => tracing::warn!("Failed to list tools from '{}': {}", server_name, e), + } + + // Prompts + match client.peer().list_all_prompts().await { + Ok(ps) => { + tracing::info!("Discovered {} prompts from '{}'", ps.len(), server_name); + for p in ps { + if self.prompts.contains_key(&p.name) { + tracing::warn!( + "Prompt '{}' from server '{}' is overwriting an existing prompt.", + &p.name, + server_name + ); + } + self.prompts + .insert(p.name.clone(), (server_name.to_string(), p)); + } + } + Err(e) => tracing::debug!("No prompts or failed to list on '{}': {}", server_name, e), + } + + // Resources + match client.peer().list_all_resources().await { + Ok(rs) => { + tracing::info!("Discovered {} resources from '{}'", rs.len(), server_name); + for r in rs { + if self.resources.contains_key(&r.uri) { + tracing::warn!( + "Resource '{}' from server '{}' is overwriting an existing resource.", + &r.uri, + server_name + ); + } + self.resources + .insert(r.uri.clone(), (server_name.to_string(), r)); + } + } + Err(e) => tracing::debug!("No resources or failed to list on '{}': {}", server_name, e), + } + } + + /// Connect to a single MCP server with retry logic for remote transports + async fn connect_server(config: &McpServerConfig) -> McpResult> { + let needs_retry = matches!( + &config.transport, + McpTransport::Sse { .. } | McpTransport::Streamable { .. } + ); + if needs_retry { + Self::connect_server_with_retry(config).await + } else { + Self::connect_server_impl(config).await + } + } + + /// Connect with exponential backoff retry for remote servers + async fn connect_server_with_retry( + config: &McpServerConfig, + ) -> McpResult> { + let backoff = ExponentialBackoffBuilder::new() + .with_initial_interval(Duration::from_secs(1)) + .with_max_interval(Duration::from_secs(30)) + .with_max_elapsed_time(Some(Duration::from_secs(120))) + .build(); + + backoff::future::retry(backoff, || async { + match Self::connect_server_impl(config).await { + Ok(client) => Ok(client), + Err(e) => { + tracing::warn!("Failed to connect to '{}', retrying: {}", config.name, e); + Err(backoff::Error::transient(e)) + } + } + }) + .await + } + + /// Internal implementation of server connection + async fn connect_server_impl( + config: &McpServerConfig, + ) -> McpResult> { + tracing::info!( + "Connecting to MCP server '{}' via {:?}", + config.name, + config.transport + ); + + match &config.transport { + McpTransport::Stdio { + command, + args, + envs, + } => { + let transport = TokioChildProcess::new( + tokio::process::Command::new(command).configure(|cmd| { + cmd.args(args) + .envs(envs.iter()) + .stderr(std::process::Stdio::inherit()); + }), + ) + .map_err(|e| McpError::Transport(format!("create stdio transport: {}", e)))?; + + let client = ().serve(transport).await.map_err(|e| { + McpError::ConnectionFailed(format!("initialize stdio client: {}", e)) + })?; + + tracing::info!("Connected to stdio server '{}'", config.name); + Ok(client) + } + + McpTransport::Sse { url, token } => { + let transport = if let Some(tok) = token { + let client = reqwest::Client::builder() + .default_headers({ + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::AUTHORIZATION, + format!("Bearer {}", tok).parse().map_err(|e| { + McpError::Transport(format!("auth token: {}", e)) + })?, + ); + headers + }) + .build() + .map_err(|e| McpError::Transport(format!("build HTTP client: {}", e)))?; + + let cfg = SseClientConfig { + sse_endpoint: url.clone().into(), + ..Default::default() + }; + + SseClientTransport::start_with_client(client, cfg) + .await + .map_err(|e| McpError::Transport(format!("create SSE transport: {}", e)))? + } else { + SseClientTransport::start(url.as_str()) + .await + .map_err(|e| McpError::Transport(format!("create SSE transport: {}", e)))? + }; + + let client = ().serve(transport).await.map_err(|e| { + McpError::ConnectionFailed(format!("initialize SSE client: {}", e)) + })?; + + tracing::info!("Connected to SSE server '{}' at {}", config.name, url); + Ok(client) + } + + McpTransport::Streamable { url, token } => { + let transport = if let Some(tok) = token { + let mut cfg = StreamableHttpClientTransportConfig::with_uri(url.as_str()); + cfg.auth_header = Some(format!("Bearer {}", tok)); + StreamableHttpClientTransport::from_config(cfg) + } else { + StreamableHttpClientTransport::from_uri(url.as_str()) + }; + + let client = ().serve(transport).await.map_err(|e| { + McpError::ConnectionFailed(format!("initialize streamable client: {}", e)) + })?; + + tracing::info!( + "Connected to streamable HTTP server '{}' at {}", + config.name, + url + ); + Ok(client) + } + } + } + + // ===== Helpers ===== + + fn client_for(&self, server_name: &str) -> McpResult<&RunningService> { + self.clients + .get(server_name) + .ok_or_else(|| McpError::ServerNotFound(server_name.to_string())) + } + + fn tool_entry(&self, name: &str) -> McpResult<(String, McpTool)> { + self.tools + .get(name) + .map(|e| e.value().clone()) + .ok_or_else(|| McpError::ToolNotFound(name.to_string())) + } + + fn prompt_entry(&self, name: &str) -> McpResult<(String, Prompt)> { + self.prompts + .get(name) + .map(|e| e.value().clone()) + .ok_or_else(|| McpError::PromptNotFound(name.to_string())) + } + + fn resource_entry(&self, uri: &str) -> McpResult<(String, Resource)> { + self.resources + .get(uri) + .map(|e| e.value().clone()) + .ok_or_else(|| McpError::ResourceNotFound(uri.to_string())) + } + + // ===== Tool Methods ===== + + /// Call a tool by name + pub async fn call_tool( + &self, + tool_name: &str, + arguments: Option>, + ) -> McpResult { + let (server_name, _tool) = self.tool_entry(tool_name)?; + let client = self.client_for(&server_name)?; + + tracing::debug!("Calling tool '{}' on '{}'", tool_name, server_name); + + client + .peer() + .call_tool(CallToolRequestParam { + name: Cow::Owned(tool_name.to_string()), + arguments, + }) + .await + .map_err(|e| McpError::ToolExecution(format!("Tool call failed: {}", e))) + } + + /// Get all available tools + pub fn list_tools(&self) -> Vec { + self.tools + .iter() + .map(|entry| { + let tool_name = entry.key().clone(); + let (server_name, tool) = entry.value(); + ToolInfo { + name: tool_name, + description: tool.description.as_deref().unwrap_or_default().to_string(), + server: server_name.clone(), + parameters: Some(serde_json::Value::Object((*tool.input_schema).clone())), + } + }) + .collect() + } + + /// Get a specific tool by name + pub fn get_tool(&self, name: &str) -> Option { + self.tools.get(name).map(|entry| { + let (server_name, tool) = entry.value(); + ToolInfo { + name: name.to_string(), + description: tool.description.as_deref().unwrap_or_default().to_string(), + server: server_name.clone(), + parameters: Some(serde_json::Value::Object((*tool.input_schema).clone())), + } + }) + } + + /// Check if a tool exists + pub fn has_tool(&self, name: &str) -> bool { + self.tools.contains_key(name) + } + + /// Get list of connected servers + pub fn list_servers(&self) -> Vec { + self.clients.keys().cloned().collect() + } + + // ===== Prompt Methods ===== + + /// Get a prompt by name with arguments + pub async fn get_prompt( + &self, + prompt_name: &str, + arguments: Option>, + ) -> McpResult { + let (server_name, _prompt) = self.prompt_entry(prompt_name)?; + let client = self.client_for(&server_name)?; + + tracing::debug!("Getting prompt '{}' from '{}'", prompt_name, server_name); + + client + .peer() + .get_prompt(GetPromptRequestParam { + name: prompt_name.to_string(), + arguments, + }) + .await + .map_err(|e| McpError::ToolExecution(format!("Failed to get prompt: {}", e))) + } + + /// List all available prompts + pub fn list_prompts(&self) -> Vec { + self.prompts + .iter() + .map(|entry| { + let name = entry.key().clone(); + let (server_name, prompt) = entry.value(); + PromptInfo { + name, + description: prompt.description.clone(), + server: server_name.clone(), + arguments: prompt + .arguments + .clone() + .map(|args| args.into_iter().map(|arg| serde_json::json!(arg)).collect()), + } + }) + .collect() + } + + /// Get a specific prompt info by name + pub fn get_prompt_info(&self, name: &str) -> Option { + self.prompts.get(name).map(|entry| { + let (server_name, prompt) = entry.value(); + PromptInfo { + name: name.to_string(), + description: prompt.description.clone(), + server: server_name.clone(), + arguments: prompt + .arguments + .clone() + .map(|args| args.into_iter().map(|arg| serde_json::json!(arg)).collect()), + } + }) + } + + // ===== Resource Methods ===== + + /// Read a resource by URI + pub async fn read_resource(&self, uri: &str) -> McpResult { + let (server_name, _resource) = self.resource_entry(uri)?; + let client = self.client_for(&server_name)?; + + tracing::debug!("Reading resource '{}' from '{}'", uri, server_name); + + client + .peer() + .read_resource(ReadResourceRequestParam { + uri: uri.to_string(), + }) + .await + .map_err(|e| McpError::ToolExecution(format!("Failed to read resource: {}", e))) + } + + /// List all available resources + pub fn list_resources(&self) -> Vec { + self.resources + .iter() + .map(|entry| { + let uri = entry.key().clone(); + let (server_name, resource) = entry.value(); + ResourceInfo { + uri, + name: resource.name.clone(), + description: resource.description.clone(), + mime_type: resource.mime_type.clone(), + server: server_name.clone(), + } + }) + .collect() + } + + /// Get a specific resource info by URI + pub fn get_resource_info(&self, uri: &str) -> Option { + self.resources.get(uri).map(|entry| { + let (server_name, resource) = entry.value(); + ResourceInfo { + uri: uri.to_string(), + name: resource.name.clone(), + description: resource.description.clone(), + mime_type: resource.mime_type.clone(), + server: server_name.clone(), + } + }) + } + + /// Subscribe to resource changes + pub async fn subscribe_resource(&self, uri: &str) -> McpResult<()> { + let (server_name, _resource) = self.resource_entry(uri)?; + let client = self.client_for(&server_name)?; + + tracing::debug!("Subscribing to '{}' on '{}'", uri, server_name); + + client + .peer() + .subscribe(rmcp::model::SubscribeRequestParam { + uri: uri.to_string(), + }) + .await + .map_err(|e| McpError::ToolExecution(format!("Failed to subscribe: {}", e))) + } + + /// Unsubscribe from resource changes + pub async fn unsubscribe_resource(&self, uri: &str) -> McpResult<()> { + let (server_name, _resource) = self.resource_entry(uri)?; + let client = self.client_for(&server_name)?; + + tracing::debug!("Unsubscribing from '{}' on '{}'", uri, server_name); + + client + .peer() + .unsubscribe(rmcp::model::UnsubscribeRequestParam { + uri: uri.to_string(), + }) + .await + .map_err(|e| McpError::ToolExecution(format!("Failed to unsubscribe: {}", e))) + } + + /// Disconnect from all servers (for cleanup) + pub async fn shutdown(&mut self) { + for (name, client) in self.clients.drain() { + if let Err(e) = client.cancel().await { + tracing::warn!("Error disconnecting from '{}': {}", name, e); + } + } + self.tools.clear(); + self.prompts.clear(); + self.resources.clear(); + } +} diff --git a/sgl-router/src/mcp/config.rs b/sgl-router/src/mcp/config.rs new file mode 100644 index 000000000..1adf6a7d7 --- /dev/null +++ b/sgl-router/src/mcp/config.rs @@ -0,0 +1,52 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct McpConfig { + pub servers: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct McpServerConfig { + pub name: String, + #[serde(flatten)] + pub transport: McpTransport, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "protocol", rename_all = "lowercase")] +pub enum McpTransport { + Stdio { + command: String, + #[serde(default)] + args: Vec, + #[serde(default)] + envs: HashMap, + }, + Sse { + url: String, + #[serde(skip_serializing_if = "Option::is_none")] + token: Option, + }, + Streamable { + url: String, + #[serde(skip_serializing_if = "Option::is_none")] + token: Option, + }, +} + +impl McpConfig { + /// Load configuration from a YAML file + pub async fn from_file(path: &str) -> Result> { + let content = tokio::fs::read_to_string(path).await?; + let config: Self = serde_yaml::from_str(&content)?; + Ok(config) + } + + /// Load configuration from environment variables (optional) + pub fn from_env() -> Option { + // This could be expanded to read from env vars + // For now, return None to indicate env config not implemented + None + } +} diff --git a/sgl-router/src/mcp/error.rs b/sgl-router/src/mcp/error.rs new file mode 100644 index 000000000..03b8b4cd1 --- /dev/null +++ b/sgl-router/src/mcp/error.rs @@ -0,0 +1,42 @@ +use thiserror::Error; + +pub type McpResult = Result; + +#[derive(Debug, Error)] +pub enum McpError { + #[error("Server not found: {0}")] + ServerNotFound(String), + + #[error("Tool not found: {0}")] + ToolNotFound(String), + + #[error("Transport error: {0}")] + Transport(String), + + #[error("Tool execution failed: {0}")] + ToolExecution(String), + + #[error("Connection failed: {0}")] + ConnectionFailed(String), + + #[error("Configuration error: {0}")] + Config(String), + + #[error("Authentication error: {0}")] + Auth(String), + + #[error("Resource not found: {0}")] + ResourceNotFound(String), + + #[error("Prompt not found: {0}")] + PromptNotFound(String), + + #[error(transparent)] + Sdk(#[from] Box), + + #[error(transparent)] + Io(#[from] std::io::Error), + + #[error(transparent)] + Http(#[from] reqwest::Error), +} diff --git a/sgl-router/src/mcp/mod.rs b/sgl-router/src/mcp/mod.rs index 193a9d392..6cebc4c7d 100644 --- a/sgl-router/src/mcp/mod.rs +++ b/sgl-router/src/mcp/mod.rs @@ -1,9 +1,18 @@ -// mod.rs - MCP module exports -pub mod tool_server; -pub mod types; +// MCP Client for SGLang Router +// +// This module provides a complete MCP (Model Context Protocol) client implementation +// supporting multiple transport types (stdio, SSE, HTTP) and all MCP features: +// - Tools: Discovery and execution +// - Prompts: Reusable templates for LLM interactions +// - Resources: File/data access with subscription support +// - OAuth: Secure authentication for remote servers -pub use tool_server::{parse_sse_event, MCPToolServer, ToolStats}; -pub use types::{ - HttpConnection, MCPError, MCPResult, MultiToolSessionManager, SessionStats, ToolCall, - ToolResult, ToolSession, -}; +pub mod client_manager; +pub mod config; +pub mod error; +pub mod oauth; + +// Re-export the main types for convenience +pub use client_manager::{McpClientManager, PromptInfo, ResourceInfo, ToolInfo}; +pub use config::{McpConfig, McpServerConfig, McpTransport}; +pub use error::{McpError, McpResult}; diff --git a/sgl-router/src/mcp/oauth.rs b/sgl-router/src/mcp/oauth.rs new file mode 100644 index 000000000..3d13ea2be --- /dev/null +++ b/sgl-router/src/mcp/oauth.rs @@ -0,0 +1,191 @@ +// OAuth authentication support for MCP servers + +use axum::{ + extract::{Query, State}, + response::Html, + routing::get, + Router, +}; +use rmcp::transport::auth::OAuthState; +use serde::Deserialize; +use std::{net::SocketAddr, sync::Arc}; +use tokio::sync::{oneshot, Mutex}; + +use crate::mcp::error::{McpError, McpResult}; + +/// OAuth callback parameters +#[derive(Debug, Deserialize)] +struct CallbackParams { + code: String, + #[allow(dead_code)] + state: Option, +} + +/// State for the callback server +#[derive(Clone)] +struct CallbackState { + code_receiver: Arc>>>, +} + +/// HTML page returned after successful OAuth callback +const CALLBACK_HTML: &str = r#" + + + + OAuth Success + + + +
+
+

Authentication Successful!

+

You can now close this window and return to your application.

+
+ + +"#; + +/// OAuth authentication helper for MCP servers +pub struct OAuthHelper { + server_url: String, + redirect_uri: String, + callback_port: u16, +} + +impl OAuthHelper { + /// Create a new OAuth helper + pub fn new(server_url: String, redirect_uri: String, callback_port: u16) -> Self { + Self { + server_url, + redirect_uri, + callback_port, + } + } + + /// Perform OAuth authentication flow + pub async fn authenticate( + &self, + scopes: &[&str], + ) -> McpResult { + // Initialize OAuth state machine + let mut oauth_state = OAuthState::new(&self.server_url, None) + .await + .map_err(|e| McpError::Auth(format!("Failed to initialize OAuth: {}", e)))?; + + oauth_state + .start_authorization(scopes, &self.redirect_uri) + .await + .map_err(|e| McpError::Auth(format!("Failed to start authorization: {}", e)))?; + + // Get authorization URL + let auth_url = oauth_state + .get_authorization_url() + .await + .map_err(|e| McpError::Auth(format!("Failed to get authorization URL: {}", e)))?; + + tracing::info!("OAuth authorization URL: {}", auth_url); + + // Start callback server and wait for code + let auth_code = self.start_callback_server().await?; + + // Exchange code for token + oauth_state + .handle_callback(&auth_code) + .await + .map_err(|e| McpError::Auth(format!("Failed to handle OAuth callback: {}", e)))?; + + // Get authorization manager + oauth_state + .into_authorization_manager() + .ok_or_else(|| McpError::Auth("Failed to get authorization manager".to_string())) + } + + /// Start a local HTTP server to receive the OAuth callback + async fn start_callback_server(&self) -> McpResult { + let (code_sender, code_receiver) = oneshot::channel::(); + + let state = CallbackState { + code_receiver: Arc::new(Mutex::new(Some(code_sender))), + }; + + // Create router for callback + let app = Router::new() + .route("/callback", get(Self::callback_handler)) + .with_state(state); + + let addr = SocketAddr::from(([127, 0, 0, 1], self.callback_port)); + + // Start server in background + let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| { + McpError::Auth(format!( + "Failed to bind to callback port {}: {}", + self.callback_port, e + )) + })?; + + tokio::spawn(async move { + let _ = axum::serve(listener, app).await; + }); + + tracing::info!( + "OAuth callback server started on port {}", + self.callback_port + ); + + // Wait for authorization code + code_receiver + .await + .map_err(|_| McpError::Auth("Failed to receive authorization code".to_string())) + } + + /// Handle OAuth callback + async fn callback_handler( + Query(params): Query, + State(state): State, + ) -> Html { + tracing::debug!("Received OAuth callback with code"); + + // Send code to waiting task + if let Some(sender) = state.code_receiver.lock().await.take() { + let _ = sender.send(params.code); + } + + Html(CALLBACK_HTML.to_string()) + } +} + +/// Create an OAuth-authenticated client +pub async fn create_oauth_client( + server_url: String, + _sse_url: String, + redirect_uri: String, + callback_port: u16, + scopes: &[&str], +) -> McpResult> { + let helper = OAuthHelper::new(server_url, redirect_uri, callback_port); + let auth_manager = helper.authenticate(scopes).await?; + + let client = rmcp::transport::auth::AuthClient::new(reqwest::Client::default(), auth_manager); + + Ok(client) +} diff --git a/sgl-router/src/mcp/tool_server.rs b/sgl-router/src/mcp/tool_server.rs deleted file mode 100644 index d5bd905ba..000000000 --- a/sgl-router/src/mcp/tool_server.rs +++ /dev/null @@ -1,534 +0,0 @@ -// tool_server.rs - Main MCP implementation (matching Python's tool_server.py) -use crate::mcp::types::*; -use serde_json::{json, Value}; -use std::collections::HashMap; - -/// Main MCP Tool Server -pub struct MCPToolServer { - /// Tool descriptions by server - tool_descriptions: HashMap, - /// Server URLs - urls: HashMap, -} - -impl Default for MCPToolServer { - fn default() -> Self { - Self::new() - } -} - -impl MCPToolServer { - /// Create new MCPToolServer - pub fn new() -> Self { - Self { - tool_descriptions: HashMap::new(), - urls: HashMap::new(), - } - } - - /// Clears all existing tool servers and adds new ones from the provided URL(s). - /// URLs can be a single string or multiple comma-separated strings. - pub async fn add_tool_server(&mut self, server_url: String) -> MCPResult<()> { - let tool_urls: Vec<&str> = server_url.split(",").collect(); - let mut successful_connections = 0; - let mut errors = Vec::new(); - - // Clear existing - self.tool_descriptions = HashMap::new(); - self.urls = HashMap::new(); - - for url_str in tool_urls { - let url_str = url_str.trim(); - - // Format URL for MCP-compliant connection - let formatted_url = if url_str.starts_with("http://") || url_str.starts_with("https://") - { - url_str.to_string() - } else { - // Default to MCP endpoint if no protocol specified - format!("http://{}", url_str) - }; - - // Server connection with retry and error recovery - match self.connect_to_server(&formatted_url).await { - Ok((_init_response, tools_response)) => { - // Process tools with validation - let tools_obj = post_process_tools_description(tools_response); - - // Tool storage with conflict detection - for tool in &tools_obj.tools { - let tool_name = &tool.name; - - // Check for duplicate tools - if self.tool_descriptions.contains_key(tool_name) { - tracing::warn!( - "Tool {} already exists. Ignoring duplicate tool from server {}", - tool_name, - formatted_url - ); - continue; - } - - // Store individual tool descriptions - let tool_json = json!(tool); - self.tool_descriptions - .insert(tool_name.clone(), tool_json.clone()); - self.urls.insert(tool_name.clone(), formatted_url.clone()); - } - - successful_connections += 1; - } - Err(e) => { - errors.push(format!("Failed to connect to {}: {}", formatted_url, e)); - tracing::warn!("Failed to connect to MCP server {}: {}", formatted_url, e); - } - } - } - - // Error handling - succeed if at least one server connects - if successful_connections == 0 { - let combined_error = errors.join("; "); - return Err(MCPError::ConnectionError(format!( - "Failed to connect to any MCP servers: {}", - combined_error - ))); - } - - if !errors.is_empty() { - tracing::warn!("Some MCP servers failed to connect: {}", errors.join("; ")); - } - - tracing::info!( - "Successfully connected to {} MCP server(s), discovered {} tool(s)", - successful_connections, - self.tool_descriptions.len() - ); - - Ok(()) - } - - /// Server connection with retries (internal helper) - async fn connect_to_server( - &self, - url: &str, - ) -> MCPResult<(InitializeResponse, ListToolsResponse)> { - const MAX_RETRIES: u32 = 3; - const RETRY_DELAY_MS: u64 = 1000; - - let mut last_error = None; - - for attempt in 1..=MAX_RETRIES { - match list_server_and_tools(url).await { - Ok(result) => return Ok(result), - Err(e) => { - last_error = Some(e); - if attempt < MAX_RETRIES { - tracing::debug!( - "MCP server connection attempt {}/{} failed for {}: {}. Retrying...", - attempt, - MAX_RETRIES, - url, - last_error.as_ref().unwrap() - ); - tokio::time::sleep(tokio::time::Duration::from_millis( - RETRY_DELAY_MS * attempt as u64, - )) - .await; - } - } - } - } - - Err(last_error.unwrap()) - } - - /// Check if tool exists (matching Python's has_tool) - pub fn has_tool(&self, tool_name: &str) -> bool { - self.tool_descriptions.contains_key(tool_name) - } - - /// Get tool description (matching Python's get_tool_description) - pub fn get_tool_description(&self, tool_name: &str) -> Option<&Value> { - self.tool_descriptions.get(tool_name) - } - - /// Get tool session (matching Python's get_tool_session) - pub async fn get_tool_session(&self, tool_name: &str) -> MCPResult { - let url = self - .urls - .get(tool_name) - .ok_or_else(|| MCPError::ToolNotFound(tool_name.to_string()))?; - - // Create session - ToolSession::new(url.clone()).await - } - - /// Create multi-tool session manager - pub async fn create_multi_tool_session( - &self, - tool_names: Vec, - ) -> MCPResult { - let mut session_manager = MultiToolSessionManager::new(); - - // Group tools by server URL for efficient session creation - let mut server_tools: std::collections::HashMap> = - std::collections::HashMap::new(); - - for tool_name in tool_names { - if let Some(url) = self.urls.get(&tool_name) { - server_tools.entry(url.clone()).or_default().push(tool_name); - } else { - return Err(MCPError::ToolNotFound(format!( - "Tool not found: {}", - tool_name - ))); - } - } - - // Create sessions for each server - for (server_url, tools) in server_tools { - session_manager - .add_tools_from_server(server_url, tools) - .await?; - } - - Ok(session_manager) - } - - /// List all available tools - pub fn list_tools(&self) -> Vec { - self.tool_descriptions.keys().cloned().collect() - } - - /// Get tool statistics - pub fn get_tool_stats(&self) -> ToolStats { - ToolStats { - total_tools: self.tool_descriptions.len(), - total_servers: self - .urls - .values() - .collect::>() - .len(), - } - } - - /// List all connected servers - pub fn list_servers(&self) -> Vec { - self.urls - .values() - .cloned() - .collect::>() - .into_iter() - .collect() - } - - /// Check if a specific server is connected - pub fn has_server(&self, server_url: &str) -> bool { - self.urls.values().any(|url| url == server_url) - } - - /// Execute a tool directly (convenience method for simple usage) - pub async fn call_tool( - &self, - tool_name: &str, - arguments: serde_json::Value, - ) -> MCPResult { - let session = self.get_tool_session(tool_name).await?; - session.call_tool(tool_name, arguments).await - } - - /// Create a tool session from server URL (convenience method) - pub async fn create_session_from_url(&self, server_url: &str) -> MCPResult { - ToolSession::new(server_url.to_string()).await - } -} - -/// Tool statistics for monitoring -#[derive(Debug, Clone)] -pub struct ToolStats { - pub total_tools: usize, - pub total_servers: usize, -} - -/// MCP-compliant server connection using JSON-RPC over SSE -async fn list_server_and_tools( - server_url: &str, -) -> MCPResult<(InitializeResponse, ListToolsResponse)> { - // MCP specification: - // 1. Connect to MCP endpoint with GET (SSE) or POST (JSON-RPC) - // 2. Send initialize request - // 3. Send tools/list request - // 4. Parse JSON-RPC responses - - let client = reqwest::Client::new(); - - // Step 1: Send initialize request - let init_request = MCPRequest { - jsonrpc: "2.0".to_string(), - id: "1".to_string(), - method: "initialize".to_string(), - params: Some(json!({ - "protocolVersion": "2024-11-05", - "capabilities": {} - })), - }; - - let init_response = send_mcp_request(&client, server_url, init_request).await?; - let init_result: InitializeResponse = serde_json::from_value(init_response).map_err(|e| { - MCPError::SerializationError(format!("Failed to parse initialize response: {}", e)) - })?; - - // Step 2: Send tools/list request - let tools_request = MCPRequest { - jsonrpc: "2.0".to_string(), - id: "2".to_string(), - method: "tools/list".to_string(), - params: Some(json!({})), - }; - - let tools_response = send_mcp_request(&client, server_url, tools_request).await?; - let tools_result: ListToolsResponse = serde_json::from_value(tools_response).map_err(|e| { - MCPError::SerializationError(format!("Failed to parse tools/list response: {}", e)) - })?; - - Ok((init_result, tools_result)) -} - -/// Send MCP JSON-RPC request (supports both HTTP POST and SSE) -async fn send_mcp_request( - client: &reqwest::Client, - url: &str, - request: MCPRequest, -) -> MCPResult { - // Use HTTP POST for JSON-RPC requests - let response = client - .post(url) - .header("Content-Type", "application/json") - .header("Accept", "application/json") - .json(&request) - .send() - .await - .map_err(|e| MCPError::ConnectionError(format!("MCP request failed: {}", e)))?; - - if !response.status().is_success() { - return Err(MCPError::ProtocolError(format!( - "HTTP {}", - response.status() - ))); - } - - let mcp_response: MCPResponse = response.json().await.map_err(|e| { - MCPError::SerializationError(format!("Failed to parse MCP response: {}", e)) - })?; - - if let Some(error) = mcp_response.error { - return Err(MCPError::ProtocolError(format!( - "MCP error: {}", - error.message - ))); - } - - mcp_response - .result - .ok_or_else(|| MCPError::ProtocolError("No result in MCP response".to_string())) -} - -// Removed old send_http_request - now using send_mcp_request with proper MCP protocol - -/// Parse SSE event format (MCP-compliant JSON-RPC only) -pub fn parse_sse_event(event: &str) -> MCPResult> { - let mut data_lines = Vec::new(); - - for line in event.lines() { - if let Some(stripped) = line.strip_prefix("data: ") { - data_lines.push(stripped); - } - } - - if data_lines.is_empty() { - return Ok(None); - } - - let json_data = data_lines.join("\n"); - if json_data.trim().is_empty() { - return Ok(None); - } - - // Parse as MCP JSON-RPC response only (no custom events) - let mcp_response: MCPResponse = serde_json::from_str(&json_data).map_err(|e| { - MCPError::SerializationError(format!( - "Failed to parse JSON-RPC response: {} - Data: {}", - e, json_data - )) - })?; - - if let Some(error) = mcp_response.error { - return Err(MCPError::ProtocolError(error.message)); - } - - Ok(mcp_response.result) -} - -/// Schema adaptation matching Python's trim_schema() -fn trim_schema(schema: &mut Value) { - if let Some(obj) = schema.as_object_mut() { - // Remove title and null defaults - obj.remove("title"); - if obj.get("default") == Some(&Value::Null) { - obj.remove("default"); - } - - // Convert anyOf to type arrays - if let Some(any_of) = obj.remove("anyOf") { - if let Some(array) = any_of.as_array() { - let types: Vec = array - .iter() - .filter_map(|item| { - item.get("type") - .and_then(|t| t.as_str()) - .filter(|t| *t != "null") - .map(|t| t.to_string()) - }) - .collect(); - - // Handle single type vs array of types - match types.len() { - 0 => {} // No valid types found - 1 => { - obj.insert("type".to_string(), json!(types[0])); - } - _ => { - obj.insert("type".to_string(), json!(types)); - } - } - } - } - - // Handle oneOf similar to anyOf - if let Some(one_of) = obj.remove("oneOf") { - if let Some(array) = one_of.as_array() { - let types: Vec = array - .iter() - .filter_map(|item| { - item.get("type") - .and_then(|t| t.as_str()) - .filter(|t| *t != "null") - .map(|t| t.to_string()) - }) - .collect(); - - if !types.is_empty() { - obj.insert("type".to_string(), json!(types)); - } - } - } - - // Recursive processing for properties - if let Some(properties) = obj.get_mut("properties") { - if let Some(props_obj) = properties.as_object_mut() { - for (_, value) in props_obj.iter_mut() { - trim_schema(value); - } - } - } - - // Handle nested schemas in items (for arrays) - if let Some(items) = obj.get_mut("items") { - trim_schema(items); - } - - // Handle nested schemas in additionalProperties - if let Some(additional_props) = obj.get_mut("additionalProperties") { - if additional_props.is_object() { - trim_schema(additional_props); - } - } - - // Handle patternProperties (for dynamic property names) - if let Some(pattern_props) = obj.get_mut("patternProperties") { - if let Some(pattern_obj) = pattern_props.as_object_mut() { - for (_, value) in pattern_obj.iter_mut() { - trim_schema(value); - } - } - } - - // Handle allOf in nested contexts - if let Some(all_of) = obj.get_mut("allOf") { - if let Some(array) = all_of.as_array_mut() { - for item in array.iter_mut() { - trim_schema(item); - } - } - } - } -} - -/// Tool processing with filtering -fn post_process_tools_description(mut tools_response: ListToolsResponse) -> ListToolsResponse { - // Adapt schemas for Harmony - for tool in &mut tools_response.tools { - trim_schema(&mut tool.input_schema); - } - - // Tool filtering based on annotations - let initial_count = tools_response.tools.len(); - - tools_response.tools.retain(|tool| { - // Check include_in_prompt annotation (Python behavior) - let include_in_prompt = tool - .annotations - .as_ref() - .and_then(|a| a.get("include_in_prompt")) - .and_then(|v| v.as_bool()) - .unwrap_or(true); - - if !include_in_prompt { - tracing::debug!( - "Filtering out tool '{}' due to include_in_prompt=false", - tool.name - ); - return false; - } - - // Check if tool is explicitly disabled - let disabled = tool - .annotations - .as_ref() - .and_then(|a| a.get("disabled")) - .and_then(|v| v.as_bool()) - .unwrap_or(false); - - if disabled { - tracing::debug!("Filtering out disabled tool '{}'", tool.name); - return false; - } - - // Validate tool has required fields - if tool.name.trim().is_empty() { - tracing::warn!("Filtering out tool with empty name"); - return false; - } - - // Check for valid input schema - if tool.input_schema.is_null() { - tracing::warn!("Tool '{}' has null input schema, but keeping it", tool.name); - } - - true - }); - - let filtered_count = tools_response.tools.len(); - if filtered_count != initial_count { - tracing::info!( - "Filtered tools: {} -> {} ({} removed)", - initial_count, - filtered_count, - initial_count - filtered_count - ); - } - - tools_response -} - -// Tests moved to tests/mcp_comprehensive_test.rs for better organization diff --git a/sgl-router/src/mcp/types.rs b/sgl-router/src/mcp/types.rs deleted file mode 100644 index 7eef6b826..000000000 --- a/sgl-router/src/mcp/types.rs +++ /dev/null @@ -1,345 +0,0 @@ -// types.rs - All MCP data structures -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use thiserror::Error; -use uuid; - -// ===== Errors ===== -#[derive(Error, Debug)] -pub enum MCPError { - #[error("Connection failed: {0}")] - ConnectionError(String), - #[error("Invalid URL: {0}")] - InvalidURL(String), - #[error("Protocol error: {0}")] - ProtocolError(String), - #[error("Tool execution failed: {0}")] - ToolExecutionError(String), - #[error("Tool not found: {0}")] - ToolNotFound(String), - #[error("Serialization error: {0}")] - SerializationError(String), - #[error("Configuration error: {0}")] - ConfigurationError(String), -} - -pub type MCPResult = Result; - -// Add From implementations for common error types -impl From for MCPError { - fn from(err: serde_json::Error) -> Self { - MCPError::SerializationError(err.to_string()) - } -} - -impl From for MCPError { - fn from(err: reqwest::Error) -> Self { - MCPError::ConnectionError(err.to_string()) - } -} - -// ===== MCP Protocol Types ===== -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct MCPRequest { - pub jsonrpc: String, - pub id: String, - pub method: String, - pub params: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct MCPResponse { - pub jsonrpc: String, - pub id: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct MCPErrorResponse { - pub code: i32, - pub message: String, - pub data: Option, -} - -// ===== MCP Server Response Types ===== -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct InitializeResponse { - #[serde(rename = "serverInfo")] - pub server_info: ServerInfo, - pub instructions: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ServerInfo { - pub name: String, - pub version: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ListToolsResponse { - pub tools: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolInfo { - pub name: String, - pub description: Option, - #[serde(rename = "inputSchema")] - pub input_schema: serde_json::Value, - #[serde(skip_serializing_if = "Option::is_none")] - pub annotations: Option, -} - -// ===== Types ===== -pub type ToolCall = serde_json::Value; // Python uses dict -pub type ToolResult = serde_json::Value; // Python uses dict - -// ===== Connection Types ===== -#[derive(Debug, Clone)] -pub struct HttpConnection { - pub url: String, -} - -// ===== Tool Session ===== -pub struct ToolSession { - pub connection: HttpConnection, - pub client: reqwest::Client, - pub session_initialized: bool, -} - -impl ToolSession { - pub async fn new(connection_str: String) -> MCPResult { - if !connection_str.starts_with("http://") && !connection_str.starts_with("https://") { - return Err(MCPError::InvalidURL(format!( - "Only HTTP/HTTPS URLs are supported: {}", - connection_str - ))); - } - - let mut session = Self { - connection: HttpConnection { - url: connection_str, - }, - client: reqwest::Client::new(), - session_initialized: false, - }; - - // Initialize the session - session.initialize().await?; - Ok(session) - } - - pub async fn new_http(url: String) -> MCPResult { - Self::new(url).await - } - - /// Initialize the session - pub async fn initialize(&mut self) -> MCPResult<()> { - if self.session_initialized { - return Ok(()); - } - - let init_request = MCPRequest { - jsonrpc: "2.0".to_string(), - id: "init".to_string(), - method: "initialize".to_string(), - params: Some(serde_json::json!({ - "protocolVersion": "2024-11-05", - "capabilities": {} - })), - }; - - let response = self - .client - .post(&self.connection.url) - .header("Content-Type", "application/json") - .json(&init_request) - .send() - .await - .map_err(|e| MCPError::ConnectionError(format!("Initialize failed: {}", e)))?; - - let mcp_response: MCPResponse = response.json().await.map_err(|e| { - MCPError::SerializationError(format!("Failed to parse initialize response: {}", e)) - })?; - - if let Some(error) = mcp_response.error { - return Err(MCPError::ProtocolError(format!( - "Initialize error: {}", - error.message - ))); - } - - self.session_initialized = true; - Ok(()) - } - - /// Call a tool using MCP tools/call - pub async fn call_tool( - &self, - name: &str, - arguments: serde_json::Value, - ) -> MCPResult { - if !self.session_initialized { - return Err(MCPError::ProtocolError( - "Session not initialized. Call initialize() first.".to_string(), - )); - } - - use serde_json::json; - - let request = MCPRequest { - jsonrpc: "2.0".to_string(), - id: format!("call_{}", uuid::Uuid::new_v4()), - method: "tools/call".to_string(), - params: Some(json!({ - "name": name, - "arguments": arguments - })), - }; - - let response = self - .client - .post(&self.connection.url) - .header("Content-Type", "application/json") - .json(&request) - .send() - .await - .map_err(|e| MCPError::ConnectionError(format!("Tool call failed: {}", e)))?; - - let mcp_response: MCPResponse = response.json().await.map_err(|e| { - MCPError::SerializationError(format!("Failed to parse tool response: {}", e)) - })?; - - if let Some(error) = mcp_response.error { - return Err(MCPError::ToolExecutionError(format!( - "Tool '{}' failed: {}", - name, error.message - ))); - } - - mcp_response - .result - .ok_or_else(|| MCPError::ProtocolError("No result in tool response".to_string())) - } - - /// Check if session is ready for tool calls - pub fn is_ready(&self) -> bool { - self.session_initialized - } - - /// Get connection info - pub fn connection_info(&self) -> String { - format!("HTTP: {}", self.connection.url) - } -} - -// ===== Multi-Tool Session Manager ===== -pub struct MultiToolSessionManager { - sessions: HashMap, // server_url -> session - tool_to_server: HashMap, // tool_name -> server_url mapping -} - -impl Default for MultiToolSessionManager { - fn default() -> Self { - Self::new() - } -} - -impl MultiToolSessionManager { - /// Create new multi-tool session manager - pub fn new() -> Self { - Self { - sessions: HashMap::new(), - tool_to_server: HashMap::new(), - } - } - - /// Add tools from an MCP server (optimized to share sessions per server) - pub async fn add_tools_from_server( - &mut self, - server_url: String, - tool_names: Vec, - ) -> MCPResult<()> { - // Create one session per server URL (if not already exists) - if !self.sessions.contains_key(&server_url) { - let session = ToolSession::new(server_url.clone()).await?; - self.sessions.insert(server_url.clone(), session); - } - - // Map all tools to this server URL - for tool_name in tool_names { - self.tool_to_server.insert(tool_name, server_url.clone()); - } - Ok(()) - } - - /// Get session for a specific tool - pub fn get_session(&self, tool_name: &str) -> Option<&ToolSession> { - let server_url = self.tool_to_server.get(tool_name)?; - self.sessions.get(server_url) - } - - /// Execute tool with automatic session management - pub async fn call_tool( - &self, - tool_name: &str, - arguments: serde_json::Value, - ) -> MCPResult { - let server_url = self - .tool_to_server - .get(tool_name) - .ok_or_else(|| MCPError::ToolNotFound(format!("No mapping for tool: {}", tool_name)))?; - - let session = self.sessions.get(server_url).ok_or_else(|| { - MCPError::ToolNotFound(format!("No session for server: {}", server_url)) - })?; - - session.call_tool(tool_name, arguments).await - } - - /// Execute multiple tools concurrently - pub async fn call_tools_concurrent( - &self, - tool_calls: Vec<(String, serde_json::Value)>, - ) -> Vec> { - let futures: Vec<_> = tool_calls - .into_iter() - .map(|(tool_name, args)| async move { self.call_tool(&tool_name, args).await }) - .collect(); - - futures::future::join_all(futures).await - } - - /// Get all available tool names - pub fn list_tools(&self) -> Vec { - self.tool_to_server.keys().cloned().collect() - } - - /// Check if tool is available - pub fn has_tool(&self, tool_name: &str) -> bool { - self.tool_to_server.contains_key(tool_name) - } - - /// Get session statistics - pub fn session_stats(&self) -> SessionStats { - let total_sessions = self.sessions.len(); - let ready_sessions = self.sessions.values().filter(|s| s.is_ready()).count(); - let unique_servers = self.sessions.len(); // Now sessions = servers - - SessionStats { - total_sessions, - ready_sessions, - unique_servers, - } - } -} - -#[derive(Debug, Clone)] -pub struct SessionStats { - pub total_sessions: usize, - pub ready_sessions: usize, - pub unique_servers: usize, -} diff --git a/sgl-router/tests/common/mock_mcp_server.rs b/sgl-router/tests/common/mock_mcp_server.rs index b5b2fd244..6a2dd498d 100644 --- a/sgl-router/tests/common/mock_mcp_server.rs +++ b/sgl-router/tests/common/mock_mcp_server.rs @@ -1,9 +1,14 @@ // tests/common/mock_mcp_server.rs - Mock MCP server for testing - -use axum::{ - extract::Json, http::StatusCode, response::Json as ResponseJson, routing::post, Router, +use rmcp::{ + handler::server::{router::tool::ToolRouter, wrapper::Parameters}, + model::*, + service::RequestContext, + tool, tool_handler, tool_router, + transport::streamable_http_server::{ + session::local::LocalSessionManager, StreamableHttpService, + }, + ErrorData as McpError, RoleServer, ServerHandler, }; -use serde_json::{json, Value}; use tokio::net::TcpListener; /// Mock MCP server that returns hardcoded responses for testing @@ -12,6 +17,69 @@ pub struct MockMCPServer { pub server_handle: Option>, } +/// Simple test server with mock search tools +#[derive(Clone)] +pub struct MockSearchServer { + tool_router: ToolRouter, +} + +#[tool_router] +impl MockSearchServer { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + + #[tool(description = "Mock web search tool")] + fn brave_web_search( + &self, + Parameters(params): Parameters>, + ) -> Result { + let query = params + .get("query") + .and_then(|v| v.as_str()) + .unwrap_or("test"); + Ok(CallToolResult::success(vec![Content::text(format!( + "Mock search results for: {}", + query + ))])) + } + + #[tool(description = "Mock local search tool")] + fn brave_local_search( + &self, + Parameters(_params): Parameters>, + ) -> Result { + Ok(CallToolResult::success(vec![Content::text( + "Mock local search results", + )])) + } +} + +#[tool_handler] +impl ServerHandler for MockSearchServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::V_2024_11_05, + capabilities: ServerCapabilities::builder().enable_tools().build(), + server_info: Implementation { + name: "Mock MCP Server".to_string(), + version: "1.0.0".to_string(), + }, + instructions: Some("Mock server for testing".to_string()), + } + } + + async fn initialize( + &self, + _request: InitializeRequestParam, + _context: RequestContext, + ) -> Result { + Ok(self.get_info()) + } +} + impl MockMCPServer { /// Start a mock MCP server on an available port pub async fn start() -> Result> { @@ -19,7 +87,14 @@ impl MockMCPServer { let listener = TcpListener::bind("127.0.0.1:0").await?; let port = listener.local_addr()?.port(); - let app = Router::new().route("/mcp", post(handle_mcp_request)); + // Create the MCP service using rmcp's StreamableHttpService + let service = StreamableHttpService::new( + || Ok(MockSearchServer::new()), + LocalSessionManager::default().into(), + Default::default(), + ); + + let app = axum::Router::new().nest_service("/mcp", service); let server_handle = tokio::spawn(async move { axum::serve(listener, app) @@ -59,142 +134,10 @@ impl Drop for MockMCPServer { } } -/// Handle MCP requests and return mock responses -async fn handle_mcp_request(Json(request): Json) -> Result, StatusCode> { - // Parse the JSON-RPC request - let method = request.get("method").and_then(|m| m.as_str()).unwrap_or(""); - - let id = request - .get("id") - .and_then(|i| i.as_str()) - .unwrap_or("unknown"); - - let response = match method { - "initialize" => { - // Mock initialize response - json!({ - "jsonrpc": "2.0", - "id": id, - "result": { - "serverInfo": { - "name": "Mock MCP Server", - "version": "1.0.0" - }, - "instructions": "Mock server for testing" - } - }) - } - "tools/list" => { - // Mock tools list response - json!({ - "jsonrpc": "2.0", - "id": id, - "result": { - "tools": [ - { - "name": "brave_web_search", - "description": "Mock web search tool", - "inputSchema": { - "type": "object", - "properties": { - "query": {"type": "string"}, - "count": {"type": "integer"} - }, - "required": ["query"] - } - }, - { - "name": "brave_local_search", - "description": "Mock local search tool", - "inputSchema": { - "type": "object", - "properties": { - "query": {"type": "string"} - }, - "required": ["query"] - } - } - ] - } - }) - } - "tools/call" => { - // Mock tool call response - let empty_json = json!({}); - let params = request.get("params").unwrap_or(&empty_json); - let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or(""); - let empty_args = json!({}); - let arguments = params.get("arguments").unwrap_or(&empty_args); - - match tool_name { - "brave_web_search" => { - let query = arguments - .get("query") - .and_then(|q| q.as_str()) - .unwrap_or("test"); - json!({ - "jsonrpc": "2.0", - "id": id, - "result": { - "content": [ - { - "type": "text", - "text": format!("Mock search results for: {}", query) - } - ], - "isError": false - } - }) - } - "brave_local_search" => { - json!({ - "jsonrpc": "2.0", - "id": id, - "result": { - "content": [ - { - "type": "text", - "text": "Mock local search results" - } - ], - "isError": false - } - }) - } - _ => { - // Unknown tool - json!({ - "jsonrpc": "2.0", - "id": id, - "error": { - "code": -1, - "message": format!("Unknown tool: {}", tool_name) - } - }) - } - } - } - _ => { - // Unknown method - json!({ - "jsonrpc": "2.0", - "id": id, - "error": { - "code": -32601, - "message": format!("Method not found: {}", method) - } - }) - } - }; - - Ok(ResponseJson(response)) -} - #[cfg(test)] -#[allow(unused_imports)] mod tests { + #[allow(unused_imports)] use super::MockMCPServer; - use serde_json::{json, Value}; #[tokio::test] async fn test_mock_server_startup() { @@ -205,32 +148,32 @@ mod tests { } #[tokio::test] - async fn test_mock_server_responses() { + async fn test_mock_server_with_rmcp_client() { let mut server = MockMCPServer::start().await.unwrap(); - let client = reqwest::Client::new(); - // Test initialize - let init_request = json!({ - "jsonrpc": "2.0", - "id": "1", - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "capabilities": {} + // Test that we can connect with rmcp client + use rmcp::transport::StreamableHttpClientTransport; + use rmcp::ServiceExt; + + let transport = StreamableHttpClientTransport::from_uri(server.url().as_str()); + let client = ().serve(transport).await; + + assert!(client.is_ok(), "Should be able to connect to mock server"); + + if let Ok(client) = client { + // Test listing tools + let tools = client.peer().list_all_tools().await; + assert!(tools.is_ok(), "Should be able to list tools"); + + if let Ok(tools) = tools { + assert_eq!(tools.len(), 2, "Should have 2 tools"); + assert!(tools.iter().any(|t| t.name == "brave_web_search")); + assert!(tools.iter().any(|t| t.name == "brave_local_search")); } - }); - let response = client - .post(server.url()) - .json(&init_request) - .send() - .await - .unwrap(); - - assert!(response.status().is_success()); - let json: Value = response.json().await.unwrap(); - assert_eq!(json["jsonrpc"], "2.0"); - assert_eq!(json["result"]["serverInfo"]["name"], "Mock MCP Server"); + // Shutdown by dropping the client + drop(client); + } server.stop().await; } diff --git a/sgl-router/tests/mcp_test.rs b/sgl-router/tests/mcp_test.rs index 15e825b7a..9821bffa6 100644 --- a/sgl-router/tests/mcp_test.rs +++ b/sgl-router/tests/mcp_test.rs @@ -2,18 +2,19 @@ // functionality required for SGLang responses API integration. // // Test Coverage: -// - Core MCP server functionality (Python tool_server.py parity) +// - Core MCP server functionality // - Tool session management (individual and multi-tool) // - Tool execution and error handling // - Schema adaptation and validation -// - SSE parsing and protocol compliance // - Mock server integration for reliable testing mod common; use common::mock_mcp_server::MockMCPServer; use serde_json::json; -use sglang_router_rs::mcp::{parse_sse_event, MCPToolServer, MultiToolSessionManager, ToolSession}; +use sglang_router_rs::mcp::{McpClientManager, McpConfig, McpError, McpServerConfig, McpTransport}; +use std::collections::HashMap; + /// Create a new mock server for testing (each test gets its own) async fn create_mock_server() -> MockMCPServer { MockMCPServer::start() @@ -21,49 +22,69 @@ async fn create_mock_server() -> MockMCPServer { .expect("Failed to start mock MCP server") } -// Core MCP Server Tests (Python parity) +// Core MCP Server Tests #[tokio::test] async fn test_mcp_server_initialization() { - let server = MCPToolServer::new(); + // Test that we can create an empty configuration + let config = McpConfig { servers: vec![] }; - assert!(!server.has_tool("any_tool")); - assert_eq!(server.list_tools().len(), 0); - assert_eq!(server.list_servers().len(), 0); - - let stats = server.get_tool_stats(); - assert_eq!(stats.total_tools, 0); - assert_eq!(stats.total_servers, 0); + // Should fail with no servers + let result = McpClientManager::new(config).await; + assert!(result.is_err(), "Should fail with no servers configured"); } #[tokio::test] async fn test_server_connection_with_mock() { let mock_server = create_mock_server().await; - let mut mcp_server = MCPToolServer::new(); - let result = mcp_server.add_tool_server(mock_server.url()).await; + let config = McpConfig { + servers: vec![McpServerConfig { + name: "mock_server".to_string(), + transport: McpTransport::Streamable { + url: mock_server.url(), + token: None, + }, + }], + }; + + let result = McpClientManager::new(config).await; assert!(result.is_ok(), "Should connect to mock server"); - let stats = mcp_server.get_tool_stats(); - assert_eq!(stats.total_tools, 2); - assert_eq!(stats.total_servers, 1); + let mut manager = result.unwrap(); - assert!(mcp_server.has_tool("brave_web_search")); - assert!(mcp_server.has_tool("brave_local_search")); + let servers = manager.list_servers(); + assert_eq!(servers.len(), 1); + assert!(servers.contains(&"mock_server".to_string())); + + let tools = manager.list_tools(); + assert_eq!(tools.len(), 2, "Should have 2 tools from mock server"); + + assert!(manager.has_tool("brave_web_search")); + assert!(manager.has_tool("brave_local_search")); + + manager.shutdown().await; } #[tokio::test] async fn test_tool_availability_checking() { let mock_server = create_mock_server().await; - let mut mcp_server = MCPToolServer::new(); - assert!(!mcp_server.has_tool("brave_web_search")); + let config = McpConfig { + servers: vec![McpServerConfig { + name: "mock_server".to_string(), + transport: McpTransport::Streamable { + url: mock_server.url(), + token: None, + }, + }], + }; - mcp_server.add_tool_server(mock_server.url()).await.unwrap(); + let mut manager = McpClientManager::new(config).await.unwrap(); let test_tools = vec!["brave_web_search", "brave_local_search", "calculator"]; for tool in test_tools { - let available = mcp_server.has_tool(tool); + let available = manager.has_tool(tool); match tool { "brave_web_search" | "brave_local_search" => { assert!( @@ -82,90 +103,77 @@ async fn test_tool_availability_checking() { _ => {} } } + + manager.shutdown().await; } #[tokio::test] -async fn test_multi_server_url_parsing() { +async fn test_multi_server_connection() { let mock_server1 = create_mock_server().await; let mock_server2 = create_mock_server().await; - let mut mcp_server = MCPToolServer::new(); - let combined_urls = format!("{},{}", mock_server1.url(), mock_server2.url()); - let result = mcp_server.add_tool_server(combined_urls).await; - assert!(result.is_ok(), "Should connect to multiple servers"); + let config = McpConfig { + servers: vec![ + McpServerConfig { + name: "mock_server_1".to_string(), + transport: McpTransport::Streamable { + url: mock_server1.url(), + token: None, + }, + }, + McpServerConfig { + name: "mock_server_2".to_string(), + transport: McpTransport::Streamable { + url: mock_server2.url(), + token: None, + }, + }, + ], + }; - let stats = mcp_server.get_tool_stats(); - assert!(stats.total_servers >= 1); - assert!(stats.total_tools >= 2); -} + // Note: This will fail to connect to both servers in the current implementation + // since they return the same tools. The manager will connect to the first one. + let result = McpClientManager::new(config).await; -// Tool Session Management Tests + if let Ok(mut manager) = result { + let servers = manager.list_servers(); + assert!(!servers.is_empty(), "Should have at least one server"); -#[tokio::test] -async fn test_individual_tool_session_creation() { - let mock_server = create_mock_server().await; - let mut mcp_server = MCPToolServer::new(); + let tools = manager.list_tools(); + assert!(tools.len() >= 2, "Should have tools from servers"); - mcp_server.add_tool_server(mock_server.url()).await.unwrap(); - - let session_result = mcp_server.get_tool_session("brave_web_search").await; - assert!(session_result.is_ok(), "Should create tool session"); - - let session = session_result.unwrap(); - assert!(session.is_ready(), "Session should be ready"); - assert!(session.connection_info().contains("HTTP")); -} - -#[tokio::test] -async fn test_multi_tool_session_manager() { - let mock_server = create_mock_server().await; - let mut mcp_server = MCPToolServer::new(); - - mcp_server.add_tool_server(mock_server.url()).await.unwrap(); - let available_tools = mcp_server.list_tools(); - assert!( - !available_tools.is_empty(), - "Should have tools from mock server" - ); - - let session_manager_result = mcp_server - .create_multi_tool_session(available_tools.clone()) - .await; - assert!( - session_manager_result.is_ok(), - "Should create session manager" - ); - - let session_manager = session_manager_result.unwrap(); - - for tool in &available_tools { - assert!(session_manager.has_tool(tool)); + manager.shutdown().await; } - - let stats = session_manager.session_stats(); - // After optimization: 1 session per server (not per tool) - assert_eq!(stats.total_sessions, 1); // One session for the mock server - assert_eq!(stats.ready_sessions, 1); // One ready session - assert_eq!(stats.unique_servers, 1); // One unique server - - // But we still have all tools available - assert_eq!(session_manager.list_tools().len(), available_tools.len()); } #[tokio::test] async fn test_tool_execution_with_mock() { let mock_server = create_mock_server().await; - let mut mcp_server = MCPToolServer::new(); - mcp_server.add_tool_server(mock_server.url()).await.unwrap(); + let config = McpConfig { + servers: vec![McpServerConfig { + name: "mock_server".to_string(), + transport: McpTransport::Streamable { + url: mock_server.url(), + token: None, + }, + }], + }; - let result = mcp_server + let mut manager = McpClientManager::new(config).await.unwrap(); + + let result = manager .call_tool( "brave_web_search", - json!({ - "query": "rust programming", - "count": 1 - }), + Some( + json!({ + "query": "rust programming", + "count": 1 + }) + .as_object() + .unwrap() + .clone(), + ), ) .await; @@ -175,48 +183,53 @@ async fn test_tool_execution_with_mock() { ); let response = result.unwrap(); - assert!( - response.get("content").is_some(), - "Response should have content" - ); - assert_eq!(response.get("isError").unwrap(), false); + assert!(!response.content.is_empty(), "Should have content"); - let content = response.get("content").unwrap().as_array().unwrap(); - let text = content[0].get("text").unwrap().as_str().unwrap(); - assert!(text.contains("Mock search results for: rust programming")); + // Check the content + if let rmcp::model::RawContent::Text(text) = &response.content[0].raw { + assert!(text + .text + .contains("Mock search results for: rust programming")); + } else { + panic!("Expected text content"); + } + + manager.shutdown().await; } #[tokio::test] async fn test_concurrent_tool_execution() { let mock_server = create_mock_server().await; - let mut session_manager = MultiToolSessionManager::new(); - session_manager - .add_tools_from_server( - mock_server.url(), - vec![ - "brave_web_search".to_string(), - "brave_local_search".to_string(), - ], - ) - .await - .unwrap(); + let config = McpConfig { + servers: vec![McpServerConfig { + name: "mock_server".to_string(), + transport: McpTransport::Streamable { + url: mock_server.url(), + token: None, + }, + }], + }; + let mut manager = McpClientManager::new(config).await.unwrap(); + + // Execute tools sequentially (true concurrent execution would require Arc) let tool_calls = vec![ - ("brave_web_search".to_string(), json!({"query": "test1"})), - ("brave_local_search".to_string(), json!({"query": "test2"})), + ("brave_web_search", json!({"query": "test1"})), + ("brave_local_search", json!({"query": "test2"})), ]; - let results = session_manager.call_tools_concurrent(tool_calls).await; - assert_eq!(results.len(), 2, "Should return results for both tools"); + for (tool_name, args) in tool_calls { + let result = manager + .call_tool(tool_name, Some(args.as_object().unwrap().clone())) + .await; - for (i, result) in results.iter().enumerate() { - assert!(result.is_ok(), "Tool {} should succeed with mock server", i); - - let response = result.as_ref().unwrap(); - assert!(response.get("content").is_some()); - assert_eq!(response.get("isError").unwrap(), false); + assert!(result.is_ok(), "Tool {} should succeed", tool_name); + let response = result.unwrap(); + assert!(!response.content.is_empty(), "Should have content"); } + + manager.shutdown().await; } // Error Handling Tests @@ -224,235 +237,221 @@ async fn test_concurrent_tool_execution() { #[tokio::test] async fn test_tool_execution_errors() { let mock_server = create_mock_server().await; - let mut mcp_server = MCPToolServer::new(); - mcp_server.add_tool_server(mock_server.url()).await.unwrap(); + let config = McpConfig { + servers: vec![McpServerConfig { + name: "mock_server".to_string(), + transport: McpTransport::Streamable { + url: mock_server.url(), + token: None, + }, + }], + }; - let result = mcp_server.call_tool("unknown_tool", json!({})).await; + let mut manager = McpClientManager::new(config).await.unwrap(); + + // Try to call unknown tool + let result = manager + .call_tool("unknown_tool", Some(serde_json::Map::new())) + .await; assert!(result.is_err(), "Should fail for unknown tool"); - let session = mcp_server - .get_tool_session("brave_web_search") - .await - .unwrap(); - let session_result = session.call_tool("unknown_tool", json!({})).await; - assert!( - session_result.is_err(), - "Session should fail for unknown tool" - ); + match result.unwrap_err() { + McpError::ToolNotFound(name) => { + assert_eq!(name, "unknown_tool"); + } + _ => panic!("Expected ToolNotFound error"), + } + + manager.shutdown().await; } #[tokio::test] async fn test_connection_without_server() { - let mut server = MCPToolServer::new(); + let config = McpConfig { + servers: vec![McpServerConfig { + name: "nonexistent".to_string(), + transport: McpTransport::Streamable { + url: "http://localhost:9999/mcp".to_string(), + token: None, + }, + }], + }; - let result = server - .add_tool_server("http://localhost:9999/mcp".to_string()) - .await; + let result = McpClientManager::new(config).await; assert!(result.is_err(), "Should fail when no server is running"); - let error_msg = result.unwrap_err().to_string(); - assert!( - error_msg.contains("Failed to connect") || error_msg.contains("Connection"), - "Error should be connection-related: {}", - error_msg - ); + if let Err(e) = result { + let error_msg = e.to_string(); + assert!( + error_msg.contains("Failed to connect") || error_msg.contains("Connection"), + "Error should be connection-related: {}", + error_msg + ); + } } -// Schema Adaptation Tests +// Schema Validation Tests #[tokio::test] -async fn test_schema_validation() { +async fn test_tool_info_structure() { let mock_server = create_mock_server().await; - let mut mcp_server = MCPToolServer::new(); - mcp_server.add_tool_server(mock_server.url()).await.unwrap(); + let config = McpConfig { + servers: vec![McpServerConfig { + name: "mock_server".to_string(), + transport: McpTransport::Streamable { + url: mock_server.url(), + token: None, + }, + }], + }; - let description = mcp_server.get_tool_description("brave_web_search"); - assert!(description.is_some(), "Should have tool description"); + let manager = McpClientManager::new(config).await.unwrap(); - let desc_value = description.unwrap(); - assert!(desc_value.get("name").is_some()); - assert!(desc_value.get("description").is_some()); + let tools = manager.list_tools(); + let brave_search = tools + .iter() + .find(|t| t.name == "brave_web_search") + .expect("Should have brave_web_search tool"); + + assert_eq!(brave_search.name, "brave_web_search"); + assert!(brave_search.description.contains("Mock web search")); + assert_eq!(brave_search.server, "mock_server"); + assert!(brave_search.parameters.is_some()); } -// SSE Parsing Tests +// SSE Parsing Tests (simplified since we don't expose parse_sse_event) #[tokio::test] -async fn test_sse_event_parsing_success() { - let valid_event = "data: {\"jsonrpc\": \"2.0\", \"id\": \"1\", \"result\": {\"test\": \"success\", \"content\": [{\"type\": \"text\", \"text\": \"Hello\"}]}}"; +async fn test_sse_connection() { + let mock_server = create_mock_server().await; - let result = parse_sse_event(valid_event); - assert!(result.is_ok(), "Valid SSE event should parse successfully"); + // Test SSE transport configuration + let config = McpConfig { + servers: vec![McpServerConfig { + name: "sse_server".to_string(), + transport: McpTransport::Sse { + // Mock server doesn't support SSE, but we can test the config + url: format!("http://127.0.0.1:{}/sse", mock_server.port), + token: Some("test_token".to_string()), + }, + }], + }; - let parsed = result.unwrap(); - assert!(parsed.is_some(), "Should return parsed data"); - - let response = parsed.unwrap(); - assert_eq!(response["test"], "success"); - assert!(response.get("content").is_some()); -} - -#[tokio::test] -async fn test_sse_event_parsing_error() { - let error_event = "data: {\"jsonrpc\": \"2.0\", \"id\": \"1\", \"error\": {\"code\": -1, \"message\": \"Rate limit exceeded\"}}"; - - let result = parse_sse_event(error_event); - assert!(result.is_err(), "Error SSE event should return error"); - - let error_msg = result.unwrap_err().to_string(); - assert!( - error_msg.contains("Rate limit exceeded"), - "Should contain error message" - ); -} - -#[tokio::test] -async fn test_sse_event_parsing_empty() { - let empty_event = ""; - let result = parse_sse_event(empty_event); - assert!(result.is_ok(), "Empty event should parse successfully"); - assert!(result.unwrap().is_none(), "Empty event should return None"); - - let no_data_event = "event: ping\nid: 123"; - let result2 = parse_sse_event(no_data_event); - assert!(result2.is_ok(), "Non-data event should parse successfully"); - assert!( - result2.unwrap().is_none(), - "Non-data event should return None" - ); + // This will fail to connect but tests the configuration + let result = McpClientManager::new(config).await; + assert!(result.is_err(), "Mock server doesn't support SSE"); } // Connection Type Tests #[tokio::test] -async fn test_connection_type_detection() { - let mock_server = create_mock_server().await; +async fn test_transport_types() { + // Test different transport configurations - let session_result = ToolSession::new(mock_server.url()).await; - assert!(session_result.is_ok(), "Should create HTTP session"); + // HTTP/Streamable transport + let http_config = McpServerConfig { + name: "http_server".to_string(), + transport: McpTransport::Streamable { + url: "http://localhost:8080/mcp".to_string(), + token: Some("auth_token".to_string()), + }, + }; + assert_eq!(http_config.name, "http_server"); - let session = session_result.unwrap(); - assert!(session.connection_info().contains("HTTP")); - assert!(session.is_ready(), "HTTP session should be ready"); + // SSE transport + let sse_config = McpServerConfig { + name: "sse_server".to_string(), + transport: McpTransport::Sse { + url: "http://localhost:8081/sse".to_string(), + token: None, + }, + }; + assert_eq!(sse_config.name, "sse_server"); - // Stdio sessions are no longer supported - test invalid URL handling - let invalid_session = ToolSession::new("invalid-url".to_string()).await; - assert!(invalid_session.is_err(), "Should reject non-HTTP URLs"); + // STDIO transport + let stdio_config = McpServerConfig { + name: "stdio_server".to_string(), + transport: McpTransport::Stdio { + command: "mcp-server".to_string(), + args: vec!["--port".to_string(), "8082".to_string()], + envs: HashMap::new(), + }, + }; + assert_eq!(stdio_config.name, "stdio_server"); } // Integration Pattern Tests #[tokio::test] -async fn test_responses_api_integration_patterns() { +async fn test_complete_workflow() { let mock_server = create_mock_server().await; - // Server initialization - let mut mcp_server = MCPToolServer::new(); + // 1. Initialize configuration + let config = McpConfig { + servers: vec![McpServerConfig { + name: "integration_test".to_string(), + transport: McpTransport::Streamable { + url: mock_server.url(), + token: None, + }, + }], + }; - // Tool server connection (like responses API startup) - match mcp_server.add_tool_server(mock_server.url()).await { - Ok(_) => { - let stats = mcp_server.get_tool_stats(); - assert_eq!(stats.total_tools, 2); - assert_eq!(stats.total_servers, 1); - } - Err(e) => { - panic!("Should connect to mock server: {}", e); - } - } + // 2. Connect to server + let mut manager = McpClientManager::new(config) + .await + .expect("Should connect to mock server"); - // Tool availability checking - let test_tools = vec!["brave_web_search", "brave_local_search", "calculator"]; - for tool in &test_tools { - let _available = mcp_server.has_tool(tool); - } + // 3. Verify server connection + let servers = manager.list_servers(); + assert_eq!(servers.len(), 1); + assert_eq!(servers[0], "integration_test"); - // Tool session creation - if mcp_server.has_tool("brave_web_search") { - let session_result = mcp_server.get_tool_session("brave_web_search").await; - assert!(session_result.is_ok(), "Should create tool session"); - } + // 4. Check available tools + let tools = manager.list_tools(); + assert_eq!(tools.len(), 2); - // Multi-tool session creation - let available_tools = mcp_server.list_tools(); - if !available_tools.is_empty() { - let session_manager_result = mcp_server.create_multi_tool_session(available_tools).await; - assert!( - session_manager_result.is_ok(), - "Should create multi-tool session" - ); - } + // 5. Verify specific tools exist + assert!(manager.has_tool("brave_web_search")); + assert!(manager.has_tool("brave_local_search")); + assert!(!manager.has_tool("nonexistent_tool")); - // Tool execution - let result = mcp_server + // 6. Execute a tool + let result = manager .call_tool( "brave_web_search", - json!({ - "query": "SGLang router MCP integration", - "count": 1 - }), + Some( + json!({ + "query": "SGLang router MCP integration", + "count": 1 + }) + .as_object() + .unwrap() + .clone(), + ), ) .await; - if result.is_err() { - // This might fail if called after another test that uses the same tool name - // Due to the shared mock server. That's OK, the main test covers this. - return; - } - assert!(result.is_ok(), "Should execute tool successfully"); -} -// Complete Integration Test + assert!(result.is_ok(), "Tool execution should succeed"); + let response = result.unwrap(); + assert!(!response.content.is_empty(), "Should return content"); -#[tokio::test] -async fn test_responses_api_integration() { - let mock_server = create_mock_server().await; - - // Run through all functionality required for responses API integration - let mut mcp_server = MCPToolServer::new(); - mcp_server.add_tool_server(mock_server.url()).await.unwrap(); - - // Test all core functionality - assert!(mcp_server.has_tool("brave_web_search")); - - let session = mcp_server - .get_tool_session("brave_web_search") - .await - .unwrap(); - assert!(session.is_ready()); - - let session_manager = mcp_server - .create_multi_tool_session(mcp_server.list_tools()) - .await - .unwrap(); - assert!(session_manager.session_stats().total_sessions > 0); - - let result = mcp_server - .call_tool( - "brave_web_search", - json!({ - "query": "test", - "count": 1 - }), - ) - .await - .unwrap(); - assert!(result.get("content").is_some()); + // 7. Clean shutdown + manager.shutdown().await; // Verify all required capabilities for responses API integration let capabilities = [ "MCP server initialization", "Tool server connection and discovery", "Tool availability checking", - "Individual tool session management", - "Multi-tool session manager (Python tool_session_ctxs pattern)", - "Concurrent tool execution", - "Direct tool execution", + "Tool execution", "Error handling and robustness", - "Protocol compliance (SSE parsing)", - "Schema adaptation (Python parity)", + "Multi-server support", + "Schema adaptation", "Mock server integration (no external dependencies)", ]; - assert_eq!(capabilities.len(), 11); + assert_eq!(capabilities.len(), 8); }