[router] move to mcp sdk instead (#10057)
This commit is contained in:
@@ -1,9 +1,14 @@
|
||||
// tests/common/mock_mcp_server.rs - Mock MCP server for testing
|
||||
|
||||
use axum::{
|
||||
extract::Json, http::StatusCode, response::Json as ResponseJson, routing::post, Router,
|
||||
use rmcp::{
|
||||
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
|
||||
model::*,
|
||||
service::RequestContext,
|
||||
tool, tool_handler, tool_router,
|
||||
transport::streamable_http_server::{
|
||||
session::local::LocalSessionManager, StreamableHttpService,
|
||||
},
|
||||
ErrorData as McpError, RoleServer, ServerHandler,
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Mock MCP server that returns hardcoded responses for testing
|
||||
@@ -12,6 +17,69 @@ pub struct MockMCPServer {
|
||||
pub server_handle: Option<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
/// Simple test server with mock search tools
|
||||
#[derive(Clone)]
|
||||
pub struct MockSearchServer {
|
||||
tool_router: ToolRouter<MockSearchServer>,
|
||||
}
|
||||
|
||||
#[tool_router]
|
||||
impl MockSearchServer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tool_router: Self::tool_router(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(description = "Mock web search tool")]
|
||||
fn brave_web_search(
|
||||
&self,
|
||||
Parameters(params): Parameters<serde_json::Map<String, serde_json::Value>>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let query = params
|
||||
.get("query")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("test");
|
||||
Ok(CallToolResult::success(vec![Content::text(format!(
|
||||
"Mock search results for: {}",
|
||||
query
|
||||
))]))
|
||||
}
|
||||
|
||||
#[tool(description = "Mock local search tool")]
|
||||
fn brave_local_search(
|
||||
&self,
|
||||
Parameters(_params): Parameters<serde_json::Map<String, serde_json::Value>>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
Ok(CallToolResult::success(vec![Content::text(
|
||||
"Mock local search results",
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool_handler]
|
||||
impl ServerHandler for MockSearchServer {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
protocol_version: ProtocolVersion::V_2024_11_05,
|
||||
capabilities: ServerCapabilities::builder().enable_tools().build(),
|
||||
server_info: Implementation {
|
||||
name: "Mock MCP Server".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
},
|
||||
instructions: Some("Mock server for testing".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn initialize(
|
||||
&self,
|
||||
_request: InitializeRequestParam,
|
||||
_context: RequestContext<RoleServer>,
|
||||
) -> Result<InitializeResult, McpError> {
|
||||
Ok(self.get_info())
|
||||
}
|
||||
}
|
||||
|
||||
impl MockMCPServer {
|
||||
/// Start a mock MCP server on an available port
|
||||
pub async fn start() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
@@ -19,7 +87,14 @@ impl MockMCPServer {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await?;
|
||||
let port = listener.local_addr()?.port();
|
||||
|
||||
let app = Router::new().route("/mcp", post(handle_mcp_request));
|
||||
// Create the MCP service using rmcp's StreamableHttpService
|
||||
let service = StreamableHttpService::new(
|
||||
|| Ok(MockSearchServer::new()),
|
||||
LocalSessionManager::default().into(),
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let app = axum::Router::new().nest_service("/mcp", service);
|
||||
|
||||
let server_handle = tokio::spawn(async move {
|
||||
axum::serve(listener, app)
|
||||
@@ -59,142 +134,10 @@ impl Drop for MockMCPServer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle MCP requests and return mock responses
|
||||
async fn handle_mcp_request(Json(request): Json<Value>) -> Result<ResponseJson<Value>, StatusCode> {
|
||||
// Parse the JSON-RPC request
|
||||
let method = request.get("method").and_then(|m| m.as_str()).unwrap_or("");
|
||||
|
||||
let id = request
|
||||
.get("id")
|
||||
.and_then(|i| i.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let response = match method {
|
||||
"initialize" => {
|
||||
// Mock initialize response
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"serverInfo": {
|
||||
"name": "Mock MCP Server",
|
||||
"version": "1.0.0"
|
||||
},
|
||||
"instructions": "Mock server for testing"
|
||||
}
|
||||
})
|
||||
}
|
||||
"tools/list" => {
|
||||
// Mock tools list response
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "brave_web_search",
|
||||
"description": "Mock web search tool",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"count": {"type": "integer"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "brave_local_search",
|
||||
"description": "Mock local search tool",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
}
|
||||
"tools/call" => {
|
||||
// Mock tool call response
|
||||
let empty_json = json!({});
|
||||
let params = request.get("params").unwrap_or(&empty_json);
|
||||
let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
|
||||
let empty_args = json!({});
|
||||
let arguments = params.get("arguments").unwrap_or(&empty_args);
|
||||
|
||||
match tool_name {
|
||||
"brave_web_search" => {
|
||||
let query = arguments
|
||||
.get("query")
|
||||
.and_then(|q| q.as_str())
|
||||
.unwrap_or("test");
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": format!("Mock search results for: {}", query)
|
||||
}
|
||||
],
|
||||
"isError": false
|
||||
}
|
||||
})
|
||||
}
|
||||
"brave_local_search" => {
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Mock local search results"
|
||||
}
|
||||
],
|
||||
"isError": false
|
||||
}
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
// Unknown tool
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"error": {
|
||||
"code": -1,
|
||||
"message": format!("Unknown tool: {}", tool_name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Unknown method
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": format!("Method not found: {}", method)
|
||||
}
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ResponseJson(response))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(unused_imports)]
|
||||
mod tests {
|
||||
#[allow(unused_imports)]
|
||||
use super::MockMCPServer;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_server_startup() {
|
||||
@@ -205,32 +148,32 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_server_responses() {
|
||||
async fn test_mock_server_with_rmcp_client() {
|
||||
let mut server = MockMCPServer::start().await.unwrap();
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Test initialize
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": "1",
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {}
|
||||
// Test that we can connect with rmcp client
|
||||
use rmcp::transport::StreamableHttpClientTransport;
|
||||
use rmcp::ServiceExt;
|
||||
|
||||
let transport = StreamableHttpClientTransport::from_uri(server.url().as_str());
|
||||
let client = ().serve(transport).await;
|
||||
|
||||
assert!(client.is_ok(), "Should be able to connect to mock server");
|
||||
|
||||
if let Ok(client) = client {
|
||||
// Test listing tools
|
||||
let tools = client.peer().list_all_tools().await;
|
||||
assert!(tools.is_ok(), "Should be able to list tools");
|
||||
|
||||
if let Ok(tools) = tools {
|
||||
assert_eq!(tools.len(), 2, "Should have 2 tools");
|
||||
assert!(tools.iter().any(|t| t.name == "brave_web_search"));
|
||||
assert!(tools.iter().any(|t| t.name == "brave_local_search"));
|
||||
}
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(server.url())
|
||||
.json(&init_request)
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(response.status().is_success());
|
||||
let json: Value = response.json().await.unwrap();
|
||||
assert_eq!(json["jsonrpc"], "2.0");
|
||||
assert_eq!(json["result"]["serverInfo"]["name"], "Mock MCP Server");
|
||||
// Shutdown by dropping the client
|
||||
drop(client);
|
||||
}
|
||||
|
||||
server.stop().await;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user