diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index 49ca7ceaf..672876c6c 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -683,6 +683,33 @@ pub struct CompletionStreamChoice { pub struct ResponseTool { #[serde(rename = "type")] pub r#type: ResponseToolType, + // MCP-specific fields (used when type == "mcp") + #[serde(skip_serializing_if = "Option::is_none")] + pub server_url: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub authorization: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub server_label: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub server_description: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub require_approval: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub allowed_tools: Option<Vec<String>>, +} + +impl Default for ResponseTool { + fn default() -> Self { + Self { + r#type: ResponseToolType::WebSearchPreview, + server_url: None, + authorization: None, + server_label: None, + server_description: None, + require_approval: None, + allowed_tools: None, + } + } } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -690,6 +717,7 @@ pub struct ResponseTool { pub enum ResponseToolType { WebSearchPreview, CodeInterpreter, + Mcp, } #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index 035af37dc..f858cccd5 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -6,8 +6,8 @@ use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse}; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem, - ResponseStatus, ResponseTextFormat, ResponsesGetParams, ResponsesRequest, ResponsesResponse, - 
TextFormatType, + ResponseStatus, ResponseTextFormat, ResponseTool, ResponseToolType, ResponsesGetParams, + ResponsesRequest, ResponsesResponse, TextFormatType, }; use crate::routers::header_utils::{apply_request_headers, preserve_response_headers}; use async_trait::async_trait; @@ -20,7 +20,14 @@ use axum::{ use bytes::Bytes; use futures_util::StreamExt; use serde_json::{json, to_value, Value}; -use std::{any::Any, borrow::Cow, collections::HashMap, io, sync::atomic::AtomicBool}; +use std::{ + any::Any, + borrow::Cow, + collections::HashMap, + io, + sync::{atomic::AtomicBool, Arc}, + time::SystemTime, +}; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{error, info, warn}; @@ -37,6 +44,8 @@ pub struct OpenAIRouter { healthy: AtomicBool, /// Response storage for managing conversation history response_storage: SharedResponseStorage, + /// Optional MCP manager (enabled via config presence) + mcp_manager: Option<Arc<crate::mcp::McpClientManager>>, } impl std::fmt::Debug for OpenAIRouter { @@ -210,12 +219,33 @@ impl OpenAIRouter { let circuit_breaker = CircuitBreaker::with_config(core_cb_config); + // Optional MCP manager activation via env var path (config-driven gate) + let mcp_manager = match std::env::var("SGLANG_MCP_CONFIG").ok() { + Some(path) if !path.trim().is_empty() => { + match crate::mcp::McpConfig::from_file(&path).await { + Ok(cfg) => match crate::mcp::McpClientManager::new(cfg).await { + Ok(mgr) => Some(Arc::new(mgr)), + Err(err) => { + warn!("Failed to initialize MCP manager: {}", err); + None + } + }, + Err(err) => { + warn!("Failed to load MCP config from '{}': {}", path, err); + None + } + } + } + _ => None, + }; + Ok(Self { client, base_url, circuit_breaker, healthy: AtomicBool::new(true), response_storage, + mcp_manager, }) } @@ -223,10 +253,79 @@ impl OpenAIRouter { &self, url: String, headers: Option<&HeaderMap>, - payload: Value, + mut payload: Value, original_body: &ResponsesRequest, original_previous_response_id: Option<String>, ) -> Response { 
+ // Request-scoped MCP: build from request tools if provided; otherwise fall back to router-level MCP + let req_mcp_manager = Self::mcp_manager_from_request_tools(&original_body.tools).await; + let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref()); + + // If the client requested MCP but we couldn't initialize it, fail early with a clear error + let requested_mcp = original_body + .tools + .iter() + .any(|t| matches!(t.r#type, ResponseToolType::Mcp)); + if requested_mcp && active_mcp.is_none() { + return ( + StatusCode::BAD_GATEWAY, + json!({ + "error": { + "message": "MCP server unavailable or failed to initialize from request tools", + "type": "mcp_unavailable", + "param": "tools", + } + }) + .to_string(), + ) + .into_response(); + } + + // If MCP is active, mirror one function tool into the outgoing payload + if let Some(mcp) = active_mcp { + if let Some(obj) = payload.as_object_mut() { + // Remove any non-function tools (e.g., custom "mcp" items) from outgoing payload + if let Some(v) = obj.get_mut("tools") { + if let Some(arr) = v.as_array_mut() { + arr.retain(|item| { + item.get("type") + .and_then(|v| v.as_str()) + .map(|s| s == "function") + .unwrap_or(false) + }); + if arr.is_empty() { + obj.remove("tools"); + obj.insert( + "tool_choice".to_string(), + Value::String("none".to_string()), + ); + } + } + } + // Build function tools for all discovered MCP tools + let mut tools_json = Vec::new(); + let tools = mcp.list_tools(); + for t in tools { + let parameters = t.parameters.clone().unwrap_or(serde_json::json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + })); + let tool = serde_json::json!({ + "type": "function", + "name": t.name, + "description": t.description, + "parameters": parameters + }); + tools_json.push(tool); + } + if !tools_json.is_empty() { + obj.insert("tools".to_string(), Value::Array(tools_json)); + // Ensure tool_choice auto to allow model planning + obj.insert("tool_choice".to_string(), 
Value::String("auto".to_string())); + } + } + } let request_builder = self.client.post(&url).json(&payload); // Apply headers with filtering @@ -239,7 +338,6 @@ impl OpenAIRouter { match request_builder.send().await { Ok(response) => { let status = response.status(); - if !status.is_success() { let error_text = response .text() @@ -290,16 +388,124 @@ impl OpenAIRouter { obj.insert("store".to_string(), Value::Bool(original_body.store)); } + let mut final_response_json = openai_response_json; + + if let Some(mcp) = active_mcp { + if let Some((call_id, tool_name, args_json_str)) = + Self::extract_function_call(&final_response_json) + { + info!( + "Detected function call: name={}, call_id={}, args={}", + tool_name, call_id, args_json_str + ); + + let call_started = SystemTime::now(); + let call_result = + Self::execute_mcp_call(mcp, &tool_name, &args_json_str).await; + let call_duration_ms = + call_started.elapsed().unwrap_or_default().as_millis(); + + let (output_payload, call_ok, call_error) = match call_result { + Ok((server, out)) => { + info!( + call_id = %call_id, + tool_name = %tool_name, + server = %server, + duration_ms = call_duration_ms, + "MCP tool call succeeded" + ); + (out, true, None) + } + Err(err) => { + warn!( + call_id = %call_id, + tool_name = %tool_name, + duration_ms = call_duration_ms, + error = %err, + "MCP tool call failed" + ); + ( + serde_json::json!({ + "error": err + }) + .to_string(), + false, + Some(err), + ) + } + }; + + match self + .resume_with_tool_result(ResumeWithToolArgs { + url: &url, + headers, + original_payload: &payload, + call_id: &call_id, + tool_name: &tool_name, + args_json_str: &args_json_str, + output_str: &output_payload, + original_body, + }) + .await + { + Ok(mut resumed_json) => { + if !call_ok { + if let Some(obj) = resumed_json.as_object_mut() { + let metadata_value = + obj.entry("metadata").or_insert_with(|| { + Value::Object(serde_json::Map::new()) + }); + if let Some(metadata) = + metadata_value.as_object_mut() 
+ { + if let Some(err_msg) = call_error.as_ref() { + metadata.insert( + "mcp_error".to_string(), + Value::String(err_msg.clone()), + ); + } + } + } + } + final_response_json = resumed_json; + } + Err(err) => { + warn!("Failed to resume with tool result: {}", err); + let error_body = json!({ + "error": { + "message": format!( + "Failed to resume with tool result: {}", + err + ), + "type": "internal_error", + } + }) + .to_string(); + + return ( + StatusCode::INTERNAL_SERVER_ERROR, + [("content-type", "application/json")], + error_body, + ) + .into_response(); + } + } + } else { + info!("No function call found in upstream response; skipping MCP"); + } + } + + Self::mask_tools_as_mcp(&mut final_response_json, original_body); if original_body.store { if let Err(e) = self - .store_response_internal(&openai_response_json, original_body) + .store_response_internal(&final_response_json, original_body) .await { warn!("Failed to store response: {}", e); } } - match serde_json::to_string(&openai_response_json) { + match serde_json::to_string(&final_response_json) { Ok(json_str) => ( StatusCode::OK, [("content-type", "application/json")], @@ -334,6 +540,49 @@ impl OpenAIRouter { } } + /// Build a request-scoped MCP manager from request tools, if present. 
+ async fn mcp_manager_from_request_tools( + tools: &[ResponseTool], + ) -> Option<Arc<crate::mcp::McpClientManager>> { + let tool = tools + .iter() + .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?; + let server_url = tool.server_url.as_ref()?.trim().to_string(); + if !(server_url.starts_with("http://") || server_url.starts_with("https://")) { + warn!( + "Ignoring MCP server_url with unsupported scheme: {}", + server_url + ); + return None; + } + let name = tool + .server_label + .clone() + .unwrap_or_else(|| "request-mcp".to_string()); + let token = tool.authorization.clone(); + let transport = if server_url.contains("/sse") { + crate::mcp::McpTransport::Sse { + url: server_url, + token, + } + } else { + crate::mcp::McpTransport::Streamable { + url: server_url, + token, + } + }; + let cfg = crate::mcp::McpConfig { + servers: vec![crate::mcp::McpServerConfig { name, transport }], + }; + match crate::mcp::McpClientManager::new(cfg).await { + Ok(mgr) => Some(Arc::new(mgr)), + Err(err) => { + warn!("Failed to initialize request-scoped MCP manager: {}", err); + None + } + } + } + async fn handle_streaming_response( &self, url: String, @@ -765,6 +1014,180 @@ impl OpenAIRouter { } } +struct ResumeWithToolArgs<'a> { + url: &'a str, + headers: Option<&'a HeaderMap>, + original_payload: &'a Value, + call_id: &'a str, + tool_name: &'a str, + args_json_str: &'a str, + output_str: &'a str, + original_body: &'a ResponsesRequest, +} + +impl OpenAIRouter { + fn extract_function_call(resp: &Value) -> Option<(String, String, String)> { + let output = resp.get("output")?.as_array()?; + for item in output { + let obj = item.as_object()?; + let t = obj.get("type")?.as_str()?; + if t == "function_tool_call" || t == "function_call" { + let call_id = obj + .get("call_id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .or_else(|| { + obj.get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + })?; + let name = obj.get("name")?.as_str()?.to_string(); + let arguments = 
obj.get("arguments")?.as_str()?.to_string(); + return Some((call_id, name, arguments)); + } + } + None + } + + /// Replace returned tools with the original request's MCP tool block (if present) so + /// external clients see MCP semantics rather than internal function tools. + fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) { + let mcp_tool = original_body + .tools + .iter() + .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some()); + let Some(t) = mcp_tool else { + return; + }; + + let mut m = serde_json::Map::new(); + m.insert("type".to_string(), Value::String("mcp".to_string())); + if let Some(label) = &t.server_label { + m.insert("server_label".to_string(), Value::String(label.clone())); + } + if let Some(url) = &t.server_url { + m.insert("server_url".to_string(), Value::String(url.clone())); + } + if let Some(desc) = &t.server_description { + m.insert( + "server_description".to_string(), + Value::String(desc.clone()), + ); + } + if let Some(req) = &t.require_approval { + m.insert("require_approval".to_string(), Value::String(req.clone())); + } + if let Some(allowed) = &t.allowed_tools { + m.insert( + "allowed_tools".to_string(), + Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()), + ); + } + + if let Some(obj) = resp.as_object_mut() { + obj.insert("tools".to_string(), Value::Array(vec![Value::Object(m)])); + obj.entry("tool_choice") + .or_insert(Value::String("auto".to_string())); + } + } + + async fn execute_mcp_call( + mcp_mgr: &Arc<crate::mcp::McpClientManager>, + tool_name: &str, + args_json_str: &str, + ) -> Result<(String, String), String> { + let args_value: Value = + serde_json::from_str(args_json_str).map_err(|e| format!("parse tool args: {}", e))?; + let args_obj = args_value.as_object().cloned(); + + let server_name = mcp_mgr + .get_tool(tool_name) + .map(|t| t.server) + .ok_or_else(|| format!("tool not found: {}", tool_name))?; + + let result = mcp_mgr + .call_tool(tool_name, args_obj) + .await + .map_err(|e| 
format!("tool call failed: {}", e))?; + + let output_str = serde_json::to_string(&result) + .map_err(|e| format!("Failed to serialize tool result: {}", e))?; + Ok((server_name, output_str)) + } + + async fn resume_with_tool_result(&self, args: ResumeWithToolArgs<'_>) -> Result<Value, String> { + let mut payload2 = args.original_payload.clone(); + let obj = payload2 + .as_object_mut() + .ok_or_else(|| "payload not an object".to_string())?; + + // Build function_call and tool result items per OpenAI Responses spec + let user_item = serde_json::json!({ + "type": "message", + "role": "user", + "content": args.original_body.input.clone() + }); + // temp system message since currently only support 1 turn of mcp function call + let system_item = serde_json::json!({ + "type": "message", + "role": "system", + "content": "please resume with the following tool result, and answer user's question directly, don't trigger any more tool calls" + }); + + let func_item = serde_json::json!({ + "type": "function_call", + "call_id": args.call_id, + "name": args.tool_name, + "arguments": args.args_json_str + }); + // Build tool result item as function_call_output per OpenAI Responses spec + let tool_item = serde_json::json!({ + "type": "function_call_output", + "call_id": args.call_id, + "output": args.output_str + }); + + obj.insert( + "input".to_string(), + Value::Array(vec![user_item, system_item, func_item, tool_item]), + ); + + // Ensure non-streaming and no store to upstream + obj.insert("stream".to_string(), Value::Bool(false)); + obj.insert("store".to_string(), Value::Bool(false)); + let mut req = self.client.post(args.url).json(&payload2); + if let Some(headers) = args.headers { + req = apply_request_headers(headers, req, true); + } + let resp = req + .send() + .await + .map_err(|e| format!("resume request failed: {}", e))?; + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(format!("resume upstream error {}: {}", 
status, body)); + } + let mut v = resp + .json::<Value>() + .await + .map_err(|e| format!("parse resume response: {}", e))?; + + if let Some(instr) = &args.original_body.instructions { + if let Some(obj) = v.as_object_mut() { + obj.entry("instructions") + .or_insert(Value::String(instr.clone())); + } + } + // After resume, mask tools as MCP if request used MCP + Self::mask_tools_as_mcp(&mut v, args.original_body); + if let Some(obj) = v.as_object_mut() { + obj.insert("store".to_string(), Value::Bool(args.original_body.store)); + } + Ok(v) + } +} + #[async_trait] impl super::super::RouterTrait for OpenAIRouter { fn as_any(&self) -> &dyn Any { diff --git a/sgl-router/tests/common/mock_worker.rs b/sgl-router/tests/common/mock_worker.rs index f46e6b854..369ab55c2 100755 --- a/sgl-router/tests/common/mock_worker.rs +++ b/sgl-router/tests/common/mock_worker.rs @@ -644,27 +644,96 @@ async fn responses_handler( })) .into_response() } else { - Json(json!({ - "id": format!("resp-{}", Uuid::new_v4()), - "object": "response", - "created_at": timestamp, - "model": "mock-model", - "output": [{ - "type": "message", - "role": "assistant", - "content": [{ - "type": "output_text", - "text": "This is a mock responses output." - }] - }], - "status": "completed", - "usage": { - "input_tokens": 10, - "output_tokens": 5, - "total_tokens": 15 - } - })) - .into_response() + // If tools are provided and this is the first call (no previous_response_id), + // emit a single function_tool_call to trigger the router's MCP flow. 
+ let has_tools = payload + .get("tools") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter().any(|tool| { + tool.get("type") + .and_then(|t| t.as_str()) + .map(|t| t == "function") + .unwrap_or(false) + }) + }) + .unwrap_or(false); + let has_function_output = payload + .get("input") + .and_then(|v| v.as_array()) + .map(|items| { + items.iter().any(|item| { + item.get("type") + .and_then(|t| t.as_str()) + .map(|t| t == "function_call_output") + .unwrap_or(false) + }) + }) + .unwrap_or(false); + + if has_tools && !has_function_output { + let rid = format!("resp-{}", Uuid::new_v4()); + Json(json!({ + "id": rid, + "object": "response", + "created_at": timestamp, + "model": "mock-model", + "output": [{ + "type": "function_tool_call", + "id": "call_1", + "name": "brave_web_search", + "arguments": "{\"query\":\"SGLang router MCP integration\"}", + "status": "in_progress" + }], + "status": "in_progress", + "usage": null + })) + .into_response() + } else if has_tools && has_function_output { + Json(json!({ + "id": format!("resp-{}", Uuid::new_v4()), + "object": "response", + "created_at": timestamp, + "model": "mock-model", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{ + "type": "output_text", + "text": "Tool result consumed; here is the final answer." + }] + }], + "status": "completed", + "usage": { + "input_tokens": 12, + "output_tokens": 7, + "total_tokens": 19 + } + })) + .into_response() + } else { + Json(json!({ + "id": format!("resp-{}", Uuid::new_v4()), + "object": "response", + "created_at": timestamp, + "model": "mock-model", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{ + "type": "output_text", + "text": "This is a mock responses output." 
+ }] + }], + "status": "completed", + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15 + } + })) + .into_response() + } } } diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs index e1ef4380f..a5950c805 100644 --- a/sgl-router/tests/responses_api_test.rs +++ b/sgl-router/tests/responses_api_test.rs @@ -6,6 +6,186 @@ use sglang_router_rs::protocols::spec::{ ToolChoiceValue, Truncation, UsageInfo, }; +mod common; +use common::mock_mcp_server::MockMCPServer; +use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; +use sglang_router_rs::config::{ + CircuitBreakerConfig, ConnectionMode, HealthCheckConfig, PolicyConfig, RetryConfig, + RouterConfig, RoutingMode, +}; +use sglang_router_rs::routers::RouterFactory; +use sglang_router_rs::server::AppContext; +use std::sync::Arc; + +#[tokio::test] +async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { + // Start mock MCP server + let mut mcp = MockMCPServer::start().await.expect("start mcp"); + + // Write a temp MCP config file + let mcp_yaml = format!( + "servers:\n - name: mock\n protocol: streamable\n url: {}\n", + mcp.url() + ); + let dir = tempfile::tempdir().expect("tmpdir"); + let cfg_path = dir.path().join("mcp.yaml"); + std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg"); + + // Start mock OpenAI worker + let mut worker = MockWorker::new(MockWorkerConfig { + port: 0, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let worker_url = worker.start().await.expect("start worker"); + + // Build router config (HTTP OpenAI mode) + let router_cfg = RouterConfig { + mode: RoutingMode::OpenAI { + worker_urls: vec![worker_url], + }, + connection_mode: ConnectionMode::Http, + policy: PolicyConfig::Random, + host: "127.0.0.1".to_string(), + port: 0, + max_payload_size: 8 * 1024 * 1024, + request_timeout_secs: 60, + worker_startup_timeout_secs: 
5, + worker_startup_check_interval_secs: 1, + dp_aware: false, + api_key: None, + discovery: None, + metrics: None, + log_dir: None, + log_level: Some("warn".to_string()), + request_id_headers: None, + max_concurrent_requests: 32, + queue_size: 0, + queue_timeout_secs: 5, + rate_limit_tokens_per_second: None, + cors_allowed_origins: vec![], + retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), + disable_retries: false, + disable_circuit_breaker: false, + health_check: HealthCheckConfig::default(), + enable_igw: false, + model_path: None, + tokenizer_path: None, + history_backend: sglang_router_rs::config::HistoryBackend::Memory, + oracle: None, + }; + + // Create router and context + let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx"); + let router = RouterFactory::create_router(&Arc::new(ctx)) + .await + .expect("router"); + + // Build a simple ResponsesRequest that will trigger the tool call + let req = ResponsesRequest { + background: false, + include: None, + input: ResponseInput::Text("search something".to_string()), + instructions: Some("Be brief".to_string()), + max_output_tokens: Some(64), + max_tool_calls: None, + metadata: None, + model: Some("mock-model".to_string()), + parallel_tool_calls: true, + previous_response_id: None, + reasoning: None, + service_tier: sglang_router_rs::protocols::spec::ServiceTier::Auto, + store: true, + stream: false, + temperature: Some(0.2), + tool_choice: sglang_router_rs::protocols::spec::ToolChoice::default(), + tools: vec![ResponseTool { + r#type: ResponseToolType::Mcp, + server_url: Some(mcp.url()), + authorization: None, + server_label: Some("mock".to_string()), + server_description: None, + require_approval: None, + allowed_tools: None, + }], + top_logprobs: 0, + top_p: None, + truncation: sglang_router_rs::protocols::spec::Truncation::Disabled, + user: None, + request_id: "resp_test_mcp_e2e".to_string(), + priority: 0, + frequency_penalty: 0.0, + 
presence_penalty: 0.0, + stop: None, + top_k: -1, + min_p: 0.0, + repetition_penalty: 1.0, + }; + + let resp = router + .route_responses(None, &req, req.model.as_deref()) + .await; + + assert_eq!(resp.status(), axum::http::StatusCode::OK); + + let body_bytes = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .expect("Failed to read response body"); + let body_json: serde_json::Value = + serde_json::from_slice(&body_bytes).expect("Failed to parse response JSON"); + + let output = body_json + .get("output") + .and_then(|v| v.as_array()) + .expect("response output missing"); + assert!(!output.is_empty(), "expected at least one output item"); + + let final_text = output + .iter() + .rev() + .filter_map(|entry| entry.get("content")) + .filter_map(|content| content.as_array()) + .flat_map(|parts| parts.iter()) + .filter_map(|part| part.get("text")) + .filter_map(|v| v.as_str()) + .next(); + + if let Some(text) = final_text { + assert_eq!(text, "Tool result consumed; here is the final answer."); + } else { + let call_entry = output.iter().find(|entry| { + entry.get("type") == Some(&serde_json::Value::String("function_tool_call".into())) + }); + assert!(call_entry.is_some(), "missing function tool call entry"); + if let Some(entry) = call_entry { + assert_eq!( + entry.get("status").and_then(|v| v.as_str()), + Some("in_progress"), + "function call should be in progress when no content is returned" + ); + } + } + + let tools = body_json + .get("tools") + .and_then(|v| v.as_array()) + .expect("tools array missing"); + assert_eq!(tools.len(), 1); + let tool = tools.first().unwrap(); + assert_eq!(tool.get("type").and_then(|v| v.as_str()), Some("mcp")); + assert_eq!( + tool.get("server_label").and_then(|v| v.as_str()), + Some("mock") + ); + + // Cleanup + worker.stop().await; + mcp.stop().await; +} + #[test] fn test_responses_request_creation() { let request = ResponsesRequest { @@ -29,6 +209,7 @@ fn test_responses_request_creation() { tool_choice: 
ToolChoice::Value(ToolChoiceValue::Auto), tools: vec![ResponseTool { r#type: ResponseToolType::WebSearchPreview, + ..Default::default() }], top_logprobs: 5, top_p: Some(0.9), @@ -179,6 +360,7 @@ fn test_json_serialization() { tool_choice: ToolChoice::Value(ToolChoiceValue::Required), tools: vec![ResponseTool { r#type: ResponseToolType::CodeInterpreter, + ..Default::default() }], top_logprobs: 10, top_p: Some(0.8),