[router] Steaming support for MCP Tool Calls in OpenAI Router (#11173)
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -608,29 +608,353 @@ async fn responses_handler(
|
||||
if is_stream {
|
||||
let request_id = format!("resp-{}", Uuid::new_v4());
|
||||
|
||||
let stream = stream::once(async move {
|
||||
let chunk = json!({
|
||||
"id": request_id,
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"status": "in_progress",
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "This is a mock responses streamed output."
|
||||
}]
|
||||
}]
|
||||
});
|
||||
Ok::<_, Infallible>(Event::default().data(chunk.to_string()))
|
||||
})
|
||||
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
|
||||
// Check if this is an MCP tool call scenario
|
||||
let has_tools = payload
|
||||
.get("tools")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter().any(|tool| {
|
||||
tool.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|t| t == "function")
|
||||
.unwrap_or(false)
|
||||
})
|
||||
})
|
||||
.unwrap_or(false);
|
||||
let has_function_output = payload
|
||||
.get("input")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|items| {
|
||||
items.iter().any(|item| {
|
||||
item.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|t| t == "function_call_output")
|
||||
.unwrap_or(false)
|
||||
})
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
Sse::new(stream)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
if has_tools && !has_function_output {
|
||||
// First turn: emit streaming tool call events
|
||||
let call_id = format!(
|
||||
"call_{}",
|
||||
Uuid::new_v4().to_string().split('-').next().unwrap()
|
||||
);
|
||||
let rid = request_id.clone();
|
||||
|
||||
let events = vec![
|
||||
// response.created
|
||||
Ok::<_, Infallible>(
|
||||
Event::default().event("response.created").data(
|
||||
json!({
|
||||
"type": "response.created",
|
||||
"response": {
|
||||
"id": rid.clone(),
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"status": "in_progress"
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
),
|
||||
),
|
||||
// response.in_progress
|
||||
Ok(Event::default().event("response.in_progress").data(
|
||||
json!({
|
||||
"type": "response.in_progress",
|
||||
"response": {
|
||||
"id": rid.clone(),
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"status": "in_progress"
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// response.output_item.added with function_tool_call
|
||||
Ok(Event::default().event("response.output_item.added").data(
|
||||
json!({
|
||||
"type": "response.output_item.added",
|
||||
"output_index": 0,
|
||||
"item": {
|
||||
"id": call_id.clone(),
|
||||
"type": "function_tool_call",
|
||||
"name": "brave_web_search",
|
||||
"arguments": "",
|
||||
"status": "in_progress"
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// response.function_call_arguments.delta events
|
||||
Ok(Event::default()
|
||||
.event("response.function_call_arguments.delta")
|
||||
.data(
|
||||
json!({
|
||||
"type": "response.function_call_arguments.delta",
|
||||
"output_index": 0,
|
||||
"item_id": call_id.clone(),
|
||||
"delta": "{\"query\""
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
Ok(Event::default()
|
||||
.event("response.function_call_arguments.delta")
|
||||
.data(
|
||||
json!({
|
||||
"type": "response.function_call_arguments.delta",
|
||||
"output_index": 0,
|
||||
"item_id": call_id.clone(),
|
||||
"delta": ":\"SGLang"
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
Ok(Event::default()
|
||||
.event("response.function_call_arguments.delta")
|
||||
.data(
|
||||
json!({
|
||||
"type": "response.function_call_arguments.delta",
|
||||
"output_index": 0,
|
||||
"item_id": call_id.clone(),
|
||||
"delta": " router MCP"
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
Ok(Event::default()
|
||||
.event("response.function_call_arguments.delta")
|
||||
.data(
|
||||
json!({
|
||||
"type": "response.function_call_arguments.delta",
|
||||
"output_index": 0,
|
||||
"item_id": call_id.clone(),
|
||||
"delta": " integration\"}"
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// response.function_call_arguments.done
|
||||
Ok(Event::default()
|
||||
.event("response.function_call_arguments.done")
|
||||
.data(
|
||||
json!({
|
||||
"type": "response.function_call_arguments.done",
|
||||
"output_index": 0,
|
||||
"item_id": call_id.clone()
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// response.output_item.done
|
||||
Ok(Event::default().event("response.output_item.done").data(
|
||||
json!({
|
||||
"type": "response.output_item.done",
|
||||
"output_index": 0,
|
||||
"item": {
|
||||
"id": call_id.clone(),
|
||||
"type": "function_tool_call",
|
||||
"name": "brave_web_search",
|
||||
"arguments": "{\"query\":\"SGLang router MCP integration\"}",
|
||||
"status": "completed"
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// response.completed
|
||||
Ok(Event::default().event("response.completed").data(
|
||||
json!({
|
||||
"type": "response.completed",
|
||||
"response": {
|
||||
"id": rid,
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"status": "completed"
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// [DONE]
|
||||
Ok(Event::default().data("[DONE]")),
|
||||
];
|
||||
|
||||
let stream = stream::iter(events);
|
||||
Sse::new(stream)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
} else if has_tools && has_function_output {
|
||||
// Second turn: emit streaming text response
|
||||
let rid = request_id.clone();
|
||||
let msg_id = format!(
|
||||
"msg_{}",
|
||||
Uuid::new_v4().to_string().split('-').next().unwrap()
|
||||
);
|
||||
|
||||
let events = vec![
|
||||
// response.created
|
||||
Ok::<_, Infallible>(
|
||||
Event::default().event("response.created").data(
|
||||
json!({
|
||||
"type": "response.created",
|
||||
"response": {
|
||||
"id": rid.clone(),
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"status": "in_progress"
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
),
|
||||
),
|
||||
// response.in_progress
|
||||
Ok(Event::default().event("response.in_progress").data(
|
||||
json!({
|
||||
"type": "response.in_progress",
|
||||
"response": {
|
||||
"id": rid.clone(),
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"status": "in_progress"
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// response.output_item.added with message
|
||||
Ok(Event::default().event("response.output_item.added").data(
|
||||
json!({
|
||||
"type": "response.output_item.added",
|
||||
"output_index": 0,
|
||||
"item": {
|
||||
"id": msg_id.clone(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// response.content_part.added
|
||||
Ok(Event::default().event("response.content_part.added").data(
|
||||
json!({
|
||||
"type": "response.content_part.added",
|
||||
"output_index": 0,
|
||||
"item_id": msg_id.clone(),
|
||||
"part": {
|
||||
"type": "output_text",
|
||||
"text": ""
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// response.output_text.delta events
|
||||
Ok(Event::default().event("response.output_text.delta").data(
|
||||
json!({
|
||||
"type": "response.output_text.delta",
|
||||
"output_index": 0,
|
||||
"content_index": 0,
|
||||
"delta": "Tool result"
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
Ok(Event::default().event("response.output_text.delta").data(
|
||||
json!({
|
||||
"type": "response.output_text.delta",
|
||||
"output_index": 0,
|
||||
"content_index": 0,
|
||||
"delta": " consumed;"
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
Ok(Event::default().event("response.output_text.delta").data(
|
||||
json!({
|
||||
"type": "response.output_text.delta",
|
||||
"output_index": 0,
|
||||
"content_index": 0,
|
||||
"delta": " here is the final answer."
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// response.output_text.done
|
||||
Ok(Event::default().event("response.output_text.done").data(
|
||||
json!({
|
||||
"type": "response.output_text.done",
|
||||
"output_index": 0,
|
||||
"content_index": 0,
|
||||
"text": "Tool result consumed; here is the final answer."
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// response.output_item.done
|
||||
Ok(Event::default().event("response.output_item.done").data(
|
||||
json!({
|
||||
"type": "response.output_item.done",
|
||||
"output_index": 0,
|
||||
"item": {
|
||||
"id": msg_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "Tool result consumed; here is the final answer."
|
||||
}]
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// response.completed
|
||||
Ok(Event::default().event("response.completed").data(
|
||||
json!({
|
||||
"type": "response.completed",
|
||||
"response": {
|
||||
"id": rid,
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"status": "completed",
|
||||
"usage": {
|
||||
"input_tokens": 12,
|
||||
"output_tokens": 7,
|
||||
"total_tokens": 19
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)),
|
||||
// [DONE]
|
||||
Ok(Event::default().data("[DONE]")),
|
||||
];
|
||||
|
||||
let stream = stream::iter(events);
|
||||
Sse::new(stream)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
} else {
|
||||
// Default streaming response
|
||||
let stream = stream::once(async move {
|
||||
let chunk = json!({
|
||||
"id": request_id,
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"status": "in_progress",
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "This is a mock responses streamed output."
|
||||
}]
|
||||
}]
|
||||
});
|
||||
Ok::<_, Infallible>(Event::default().data(chunk.to_string()))
|
||||
})
|
||||
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
|
||||
|
||||
Sse::new(stream)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
}
|
||||
} else if is_background {
|
||||
let rid = req_id.unwrap_or_else(|| format!("resp-{}", Uuid::new_v4()));
|
||||
Json(json!({
|
||||
|
||||
@@ -765,3 +765,464 @@ async fn test_max_tool_calls_limit() {
|
||||
worker.stop().await;
|
||||
mcp.stop().await;
|
||||
}
|
||||
|
||||
/// Helper function to set up common test infrastructure for streaming MCP tests
|
||||
/// Returns (mcp_server, worker, router, temp_dir)
|
||||
async fn setup_streaming_mcp_test() -> (
|
||||
MockMCPServer,
|
||||
MockWorker,
|
||||
Box<dyn sglang_router_rs::routers::RouterTrait>,
|
||||
tempfile::TempDir,
|
||||
) {
|
||||
let mcp = MockMCPServer::start().await.expect("start mcp");
|
||||
let mcp_yaml = format!(
|
||||
"servers:\n - name: mock\n protocol: streamable\n url: {}\n",
|
||||
mcp.url()
|
||||
);
|
||||
let dir = tempfile::tempdir().expect("tmpdir");
|
||||
let cfg_path = dir.path().join("mcp.yaml");
|
||||
std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg");
|
||||
|
||||
let mut worker = MockWorker::new(MockWorkerConfig {
|
||||
port: 0,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
});
|
||||
let worker_url = worker.start().await.expect("start worker");
|
||||
|
||||
let router_cfg = RouterConfig {
|
||||
mode: RoutingMode::OpenAI {
|
||||
worker_urls: vec![worker_url],
|
||||
},
|
||||
connection_mode: ConnectionMode::Http,
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 0,
|
||||
max_payload_size: 8 * 1024 * 1024,
|
||||
request_timeout_secs: 60,
|
||||
worker_startup_timeout_secs: 5,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: Some("info".to_string()),
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 32,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 5,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
|
||||
(mcp, worker, router, dir)
|
||||
}
|
||||
|
||||
/// Parse SSE (Server-Sent Events) stream into structured events
|
||||
fn parse_sse_events(body: &str) -> Vec<(Option<String>, serde_json::Value)> {
|
||||
let mut events = Vec::new();
|
||||
let blocks: Vec<&str> = body
|
||||
.split("\n\n")
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.collect();
|
||||
|
||||
for block in blocks {
|
||||
let mut event_name: Option<String> = None;
|
||||
let mut data_lines: Vec<String> = Vec::new();
|
||||
|
||||
for line in block.lines() {
|
||||
if let Some(rest) = line.strip_prefix("event:") {
|
||||
event_name = Some(rest.trim().to_string());
|
||||
} else if let Some(rest) = line.strip_prefix("data:") {
|
||||
let data = rest.trim_start();
|
||||
// Skip [DONE] marker
|
||||
if data != "[DONE]" {
|
||||
data_lines.push(data.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !data_lines.is_empty() {
|
||||
let data = data_lines.join("\n");
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
events.push((event_name, parsed));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
events
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_with_mcp_tool_calls() {
|
||||
// This test verifies that streaming works with MCP tool calls:
|
||||
// 1. Initial streaming request with MCP tools
|
||||
// 2. Mock worker streams text, then function_call deltas
|
||||
// 3. Router buffers function call, executes MCP tool
|
||||
// 4. Router resumes streaming with tool results
|
||||
// 5. Mock worker streams final answer
|
||||
// 6. Verify SSE events are properly formatted
|
||||
|
||||
let (mut mcp, mut worker, router, _dir) = setup_streaming_mcp_test().await;
|
||||
|
||||
// Build streaming request with MCP tools
|
||||
let req = ResponsesRequest {
|
||||
background: false,
|
||||
include: None,
|
||||
input: ResponseInput::Text("search for something interesting".to_string()),
|
||||
instructions: Some("Use tools when needed".to_string()),
|
||||
max_output_tokens: Some(256),
|
||||
max_tool_calls: Some(3),
|
||||
metadata: None,
|
||||
model: Some("mock-model".to_string()),
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
service_tier: ServiceTier::Auto,
|
||||
store: true,
|
||||
stream: true, // KEY: Enable streaming
|
||||
temperature: Some(0.7),
|
||||
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::Mcp,
|
||||
server_url: Some(mcp.url()),
|
||||
server_label: Some("mock".to_string()),
|
||||
server_description: Some("Mock MCP for streaming test".to_string()),
|
||||
require_approval: Some("never".to_string()),
|
||||
..Default::default()
|
||||
}],
|
||||
top_logprobs: 0,
|
||||
top_p: Some(1.0),
|
||||
truncation: Truncation::Disabled,
|
||||
user: None,
|
||||
request_id: "resp_streaming_mcp_test".to_string(),
|
||||
priority: 0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop: None,
|
||||
top_k: 50,
|
||||
min_p: 0.0,
|
||||
repetition_penalty: 1.0,
|
||||
};
|
||||
|
||||
let response = router.route_responses(None, &req, None).await;
|
||||
|
||||
// Verify streaming response
|
||||
assert_eq!(
|
||||
response.status(),
|
||||
axum::http::StatusCode::OK,
|
||||
"Streaming request should succeed"
|
||||
);
|
||||
|
||||
// Check Content-Type is text/event-stream
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
assert_eq!(
|
||||
content_type,
|
||||
Some("text/event-stream"),
|
||||
"Should have SSE content type"
|
||||
);
|
||||
|
||||
// Read the streaming body
|
||||
use axum::body::to_bytes;
|
||||
let response_body = response.into_body();
|
||||
let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap();
|
||||
let body_text = String::from_utf8_lossy(&body_bytes);
|
||||
|
||||
println!("Streaming SSE response:\n{}", body_text);
|
||||
|
||||
// Parse all SSE events into structured format
|
||||
let events = parse_sse_events(&body_text);
|
||||
|
||||
assert!(!events.is_empty(), "Should have at least one SSE event");
|
||||
println!("Total parsed SSE events: {}", events.len());
|
||||
|
||||
// Check for [DONE] marker
|
||||
let has_done_marker = body_text.contains("data: [DONE]");
|
||||
assert!(has_done_marker, "Stream should end with [DONE] marker");
|
||||
|
||||
// Track which events we've seen
|
||||
let mut found_mcp_list_tools = false;
|
||||
let mut found_mcp_list_tools_in_progress = false;
|
||||
let mut found_mcp_list_tools_completed = false;
|
||||
let mut found_response_created = false;
|
||||
let mut found_mcp_call_added = false;
|
||||
let mut found_mcp_call_in_progress = false;
|
||||
let mut found_mcp_call_arguments = false;
|
||||
let mut found_mcp_call_arguments_done = false;
|
||||
let mut found_mcp_call_done = false;
|
||||
let mut found_response_completed = false;
|
||||
|
||||
for (event_name, data) in &events {
|
||||
let event_type = data.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
match event_type {
|
||||
"response.output_item.added" => {
|
||||
// Check if it's an mcp_list_tools item
|
||||
if let Some(item) = data.get("item") {
|
||||
if item.get("type").and_then(|v| v.as_str()) == Some("mcp_list_tools") {
|
||||
found_mcp_list_tools = true;
|
||||
println!("✓ Found mcp_list_tools added event");
|
||||
|
||||
// Verify tools array is present (should be empty in added event)
|
||||
assert!(
|
||||
item.get("tools").is_some(),
|
||||
"mcp_list_tools should have tools array"
|
||||
);
|
||||
} else if item.get("type").and_then(|v| v.as_str()) == Some("mcp_call") {
|
||||
found_mcp_call_added = true;
|
||||
println!("✓ Found mcp_call added event");
|
||||
|
||||
// Verify mcp_call has required fields
|
||||
assert!(item.get("name").is_some(), "mcp_call should have name");
|
||||
assert_eq!(
|
||||
item.get("server_label").and_then(|v| v.as_str()),
|
||||
Some("mock"),
|
||||
"mcp_call should have server_label"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.mcp_list_tools.in_progress" => {
|
||||
found_mcp_list_tools_in_progress = true;
|
||||
println!("✓ Found mcp_list_tools.in_progress event");
|
||||
|
||||
// Verify it has output_index and item_id
|
||||
assert!(
|
||||
data.get("output_index").is_some(),
|
||||
"mcp_list_tools.in_progress should have output_index"
|
||||
);
|
||||
assert!(
|
||||
data.get("item_id").is_some(),
|
||||
"mcp_list_tools.in_progress should have item_id"
|
||||
);
|
||||
}
|
||||
"response.mcp_list_tools.completed" => {
|
||||
found_mcp_list_tools_completed = true;
|
||||
println!("✓ Found mcp_list_tools.completed event");
|
||||
|
||||
// Verify it has output_index and item_id
|
||||
assert!(
|
||||
data.get("output_index").is_some(),
|
||||
"mcp_list_tools.completed should have output_index"
|
||||
);
|
||||
assert!(
|
||||
data.get("item_id").is_some(),
|
||||
"mcp_list_tools.completed should have item_id"
|
||||
);
|
||||
}
|
||||
"response.mcp_call.in_progress" => {
|
||||
found_mcp_call_in_progress = true;
|
||||
println!("✓ Found mcp_call.in_progress event");
|
||||
|
||||
// Verify it has output_index and item_id
|
||||
assert!(
|
||||
data.get("output_index").is_some(),
|
||||
"mcp_call.in_progress should have output_index"
|
||||
);
|
||||
assert!(
|
||||
data.get("item_id").is_some(),
|
||||
"mcp_call.in_progress should have item_id"
|
||||
);
|
||||
}
|
||||
"response.mcp_call_arguments.delta" => {
|
||||
found_mcp_call_arguments = true;
|
||||
println!("✓ Found mcp_call_arguments.delta event");
|
||||
|
||||
// Delta should include arguments payload
|
||||
assert!(
|
||||
data.get("delta").is_some(),
|
||||
"mcp_call_arguments.delta should include delta text"
|
||||
);
|
||||
}
|
||||
"response.mcp_call_arguments.done" => {
|
||||
found_mcp_call_arguments_done = true;
|
||||
println!("✓ Found mcp_call_arguments.done event");
|
||||
|
||||
assert!(
|
||||
data.get("arguments").is_some(),
|
||||
"mcp_call_arguments.done should include full arguments"
|
||||
);
|
||||
}
|
||||
"response.output_item.done" => {
|
||||
if let Some(item) = data.get("item") {
|
||||
if item.get("type").and_then(|v| v.as_str()) == Some("mcp_call") {
|
||||
found_mcp_call_done = true;
|
||||
println!("✓ Found mcp_call done event");
|
||||
|
||||
// Verify mcp_call.done has output
|
||||
assert!(
|
||||
item.get("output").is_some(),
|
||||
"mcp_call done should have output"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.created" => {
|
||||
found_response_created = true;
|
||||
println!("✓ Found response.created event");
|
||||
|
||||
// Verify response has required fields
|
||||
assert!(
|
||||
data.get("response").is_some(),
|
||||
"response.created should have response object"
|
||||
);
|
||||
}
|
||||
"response.completed" => {
|
||||
found_response_completed = true;
|
||||
println!("✓ Found response.completed event");
|
||||
}
|
||||
_ => {
|
||||
println!(" Other event: {}", event_type);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(name) = event_name {
|
||||
println!(" Event name: {}", name);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify key events were present
|
||||
println!("\n=== Event Summary ===");
|
||||
println!("MCP list_tools added: {}", found_mcp_list_tools);
|
||||
println!(
|
||||
"MCP list_tools in_progress: {}",
|
||||
found_mcp_list_tools_in_progress
|
||||
);
|
||||
println!(
|
||||
"MCP list_tools completed: {}",
|
||||
found_mcp_list_tools_completed
|
||||
);
|
||||
println!("Response created: {}", found_response_created);
|
||||
println!("MCP call added: {}", found_mcp_call_added);
|
||||
println!("MCP call in_progress: {}", found_mcp_call_in_progress);
|
||||
println!("MCP call arguments delta: {}", found_mcp_call_arguments);
|
||||
println!("MCP call arguments done: {}", found_mcp_call_arguments_done);
|
||||
println!("MCP call done: {}", found_mcp_call_done);
|
||||
println!("Response completed: {}", found_response_completed);
|
||||
|
||||
// Assert critical events are present
|
||||
assert!(
|
||||
found_mcp_list_tools,
|
||||
"Should send mcp_list_tools added event at the start"
|
||||
);
|
||||
assert!(
|
||||
found_mcp_list_tools_in_progress,
|
||||
"Should send mcp_list_tools.in_progress event"
|
||||
);
|
||||
assert!(
|
||||
found_mcp_list_tools_completed,
|
||||
"Should send mcp_list_tools.completed event"
|
||||
);
|
||||
assert!(found_response_created, "Should send response.created event");
|
||||
assert!(found_mcp_call_added, "Should send mcp_call added event");
|
||||
assert!(
|
||||
found_mcp_call_in_progress,
|
||||
"Should send mcp_call.in_progress event"
|
||||
);
|
||||
assert!(found_mcp_call_done, "Should send mcp_call done event");
|
||||
|
||||
assert!(
|
||||
found_mcp_call_arguments,
|
||||
"Should send mcp_call_arguments.delta event"
|
||||
);
|
||||
assert!(
|
||||
found_mcp_call_arguments_done,
|
||||
"Should send mcp_call_arguments.done event"
|
||||
);
|
||||
|
||||
// Verify no error events
|
||||
let has_error = body_text.contains("event: error");
|
||||
assert!(!has_error, "Should not have error events");
|
||||
|
||||
worker.stop().await;
|
||||
mcp.stop().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_multi_turn_with_mcp() {
|
||||
// Test streaming with multiple tool call rounds
|
||||
let (mut mcp, mut worker, router, _dir) = setup_streaming_mcp_test().await;
|
||||
|
||||
let req = ResponsesRequest {
|
||||
background: false,
|
||||
include: None,
|
||||
input: ResponseInput::Text("complex query requiring multiple tool calls".to_string()),
|
||||
instructions: Some("Be thorough".to_string()),
|
||||
max_output_tokens: Some(512),
|
||||
max_tool_calls: Some(5), // Allow multiple rounds
|
||||
metadata: None,
|
||||
model: Some("mock-model".to_string()),
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
service_tier: ServiceTier::Auto,
|
||||
store: true,
|
||||
stream: true,
|
||||
temperature: Some(0.8),
|
||||
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::Mcp,
|
||||
server_url: Some(mcp.url()),
|
||||
server_label: Some("mock".to_string()),
|
||||
..Default::default()
|
||||
}],
|
||||
top_logprobs: 0,
|
||||
top_p: Some(1.0),
|
||||
truncation: Truncation::Disabled,
|
||||
user: None,
|
||||
request_id: "resp_streaming_multiturn_test".to_string(),
|
||||
priority: 0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop: None,
|
||||
top_k: 50,
|
||||
min_p: 0.0,
|
||||
repetition_penalty: 1.0,
|
||||
};
|
||||
|
||||
let response = router.route_responses(None, &req, None).await;
|
||||
assert_eq!(response.status(), axum::http::StatusCode::OK);
|
||||
|
||||
use axum::body::to_bytes;
|
||||
let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap();
|
||||
let body_text = String::from_utf8_lossy(&body_bytes);
|
||||
|
||||
println!("Multi-turn streaming response:\n{}", body_text);
|
||||
|
||||
// Verify streaming completed successfully
|
||||
assert!(body_text.contains("data: [DONE]"));
|
||||
assert!(!body_text.contains("event: error"));
|
||||
|
||||
// Count events
|
||||
let event_count = body_text
|
||||
.split("\n\n")
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.count();
|
||||
println!("Total events in multi-turn stream: {}", event_count);
|
||||
|
||||
assert!(event_count > 0, "Should have received streaming events");
|
||||
|
||||
worker.stop().await;
|
||||
mcp.stop().await;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user