From 84768d10175cb147da1e54b0eeb1e626ff50c276 Mon Sep 17 00:00:00 2001 From: Keyang Ru Date: Wed, 8 Oct 2025 21:46:39 -0700 Subject: [PATCH] [router] Refactor OpenAI router: split monolithic file and move location (#11359) --- sgl-router/src/routers/factory.rs | 3 +- sgl-router/src/routers/http/mod.rs | 1 - sgl-router/src/routers/http/openai_router.rs | 4547 ----------------- sgl-router/src/routers/mod.rs | 5 +- .../src/routers/openai/conversations.rs | 574 +++ sgl-router/src/routers/openai/mcp.rs | 967 ++++ sgl-router/src/routers/openai/mod.rs | 18 + sgl-router/src/routers/openai/responses.rs | 368 ++ sgl-router/src/routers/openai/router.rs | 909 ++++ sgl-router/src/routers/openai/streaming.rs | 1550 ++++++ sgl-router/src/routers/openai/utils.rs | 100 + sgl-router/tests/test_openai_routing.rs | 2 +- 12 files changed, 4492 insertions(+), 4552 deletions(-) delete mode 100644 sgl-router/src/routers/http/openai_router.rs create mode 100644 sgl-router/src/routers/openai/conversations.rs create mode 100644 sgl-router/src/routers/openai/mcp.rs create mode 100644 sgl-router/src/routers/openai/mod.rs create mode 100644 sgl-router/src/routers/openai/responses.rs create mode 100644 sgl-router/src/routers/openai/router.rs create mode 100644 sgl-router/src/routers/openai/streaming.rs create mode 100644 sgl-router/src/routers/openai/utils.rs diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index eecf8f839..5a00fa7f5 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -3,7 +3,8 @@ use super::grpc::pd_router::GrpcPDRouter; use super::grpc::router::GrpcRouter; use super::{ - http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router}, + http::{pd_router::PDRouter, router::Router}, + openai::OpenAIRouter, RouterTrait, }; use crate::config::{ConnectionMode, PolicyConfig, RoutingMode}; diff --git a/sgl-router/src/routers/http/mod.rs b/sgl-router/src/routers/http/mod.rs index 9f955b651..3f31b6f86 100644 --- a/sgl-router/src/routers/http/mod.rs +++ b/sgl-router/src/routers/http/mod.rs @@ -1,6 +1,5 @@ //! HTTP router implementations -pub mod openai_router; pub mod pd_router; pub mod pd_types; pub mod router; diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs deleted file mode 100644 index a18ed3da6..000000000 --- a/sgl-router/src/routers/http/openai_router.rs +++ /dev/null @@ -1,4547 +0,0 @@ -//! OpenAI router implementation - -use crate::config::CircuitBreakerConfig; -use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}; -use crate::data_connector::{ - Conversation, ConversationId, ConversationItemsListParams, ConversationItemsSortOrder, - ConversationMetadata, NewConversationItem as DCNewConversationItem, ResponseId, - SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, - StoredResponse, -}; -use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, - ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem, - ResponseStatus, ResponseTextFormat, ResponseTool, ResponseToolType, ResponsesGetParams, - ResponsesRequest, ResponsesResponse, TextFormatType, -}; -use crate::routers::header_utils::{apply_request_headers, preserve_response_headers}; -use async_trait::async_trait; -use axum::{ - body::Body, - extract::Request, - http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, - response::{IntoResponse, Response}, - Json, -}; -use bytes::Bytes; -use futures_util::StreamExt; -use serde_json::{json, to_value, Value}; -use std::{ - any::Any, - borrow::Cow, - collections::HashMap, - io, - sync::{atomic::AtomicBool, Arc}, -}; -use tokio::sync::mpsc; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{error, info, warn}; - -// SSE Event Type Constants - single source of truth for event type strings -mod event_types { - // Response lifecycle events - pub const RESPONSE_CREATED: &str = "response.created"; - pub const RESPONSE_IN_PROGRESS: &str = "response.in_progress"; - pub const RESPONSE_COMPLETED: &str = "response.completed"; - - // Output item events - pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added"; - pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done"; - pub const OUTPUT_ITEM_DELTA: &str = "response.output_item.delta"; - - // Function call events - pub const FUNCTION_CALL_ARGUMENTS_DELTA: &str = "response.function_call_arguments.delta"; - pub const FUNCTION_CALL_ARGUMENTS_DONE: &str = "response.function_call_arguments.done"; - - // MCP call events - pub const MCP_CALL_ARGUMENTS_DELTA: &str = "response.mcp_call_arguments.delta"; - pub const MCP_CALL_ARGUMENTS_DONE: &str = "response.mcp_call_arguments.done"; - pub const MCP_CALL_IN_PROGRESS: &str = "response.mcp_call.in_progress"; - pub const MCP_CALL_COMPLETED: &str = "response.mcp_call.completed"; - pub const MCP_LIST_TOOLS_IN_PROGRESS: &str = "response.mcp_list_tools.in_progress"; - pub const MCP_LIST_TOOLS_COMPLETED: &str = "response.mcp_list_tools.completed"; - - // Item types - pub const ITEM_TYPE_FUNCTION_CALL: &str = "function_call"; - pub const ITEM_TYPE_FUNCTION_TOOL_CALL: &str = "function_tool_call"; - pub const ITEM_TYPE_MCP_CALL: &str = "mcp_call"; - pub const ITEM_TYPE_FUNCTION: &str = "function"; - pub const ITEM_TYPE_MCP_LIST_TOOLS: &str = "mcp_list_tools"; -} - -/// Router for OpenAI backend -pub struct OpenAIRouter { - /// HTTP client for upstream OpenAI-compatible API - client: reqwest::Client, - /// Base URL for identification (no trailing slash) - base_url: String, - /// Circuit breaker - circuit_breaker: CircuitBreaker, - /// Health status - healthy: AtomicBool, - /// Response storage for managing conversation history - response_storage: SharedResponseStorage, - /// Conversation storage backend - conversation_storage: SharedConversationStorage, - /// Conversation item storage backend - conversation_item_storage: SharedConversationItemStorage, - /// Optional MCP manager (enabled via config presence) - mcp_manager: Option>, -} - -impl std::fmt::Debug for OpenAIRouter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("OpenAIRouter") - .field("base_url", &self.base_url) - .field("healthy", &self.healthy) - .finish() - } -} - -/// Configuration for MCP tool calling loops -#[derive(Debug, Clone)] -struct McpLoopConfig { - /// Maximum iterations as safety limit (internal only, default: 10) - /// Prevents infinite loops when max_tool_calls is not set - max_iterations: usize, -} - -impl Default for McpLoopConfig { - fn default() -> Self { - Self { max_iterations: 10 } - } -} - -/// State for tracking multi-turn tool calling loop -struct ToolLoopState { - /// Current iteration number (starts at 0, increments with each tool call) - iteration: usize, - /// Total number of tool calls executed - total_calls: usize, - /// Conversation history (function_call and function_call_output items) - conversation_history: Vec, - /// Original user input (preserved for building resume payloads) - original_input: ResponseInput, -} - -impl ToolLoopState { - fn new(original_input: ResponseInput) -> Self { - Self { - iteration: 0, - total_calls: 0, - conversation_history: Vec::new(), - original_input, - } - } - - /// Record a tool call in the loop state - fn record_call( - &mut self, - call_id: String, - tool_name: String, - args_json_str: String, - output_str: String, - ) { - // Add function_call item to history - let func_item = json!({ - "type": event_types::ITEM_TYPE_FUNCTION_CALL, - "call_id": call_id, - "name": tool_name, - "arguments": args_json_str - }); - self.conversation_history.push(func_item); - - // Add function_call_output item to history - let output_item = json!({ - "type": "function_call_output", - "call_id": call_id, - "output": output_str - }); - self.conversation_history.push(output_item); - } -} - -/// Helper that parses SSE frames from the OpenAI responses stream and -/// accumulates enough information to persist the final response locally. -struct StreamingResponseAccumulator { - /// The initial `response.created` payload (if emitted). - initial_response: Option, - /// The final `response.completed` payload (if emitted). - completed_response: Option, - /// Collected output items keyed by the upstream output index, used when - /// a final response payload is absent and we need to synthesize one. - output_items: Vec<(usize, Value)>, - /// Captured error payload (if the upstream stream fails midway). - encountered_error: Option, -} - -/// Represents a function call being accumulated across delta events -#[derive(Debug, Clone)] -struct FunctionCallInProgress { - call_id: String, - name: String, - arguments_buffer: String, - output_index: usize, - last_obfuscation: Option, - assigned_output_index: Option, -} - -impl FunctionCallInProgress { - fn new(call_id: String, output_index: usize) -> Self { - Self { - call_id, - name: String::new(), - arguments_buffer: String::new(), - output_index, - last_obfuscation: None, - assigned_output_index: None, - } - } - - fn is_complete(&self) -> bool { - // A tool call is complete if it has a name - !self.name.is_empty() - } - - fn effective_output_index(&self) -> usize { - self.assigned_output_index.unwrap_or(self.output_index) - } -} - -#[derive(Debug, Default)] -struct OutputIndexMapper { - next_index: usize, - // Map upstream output_index -> remapped output_index - assigned: HashMap, -} - -impl OutputIndexMapper { - fn with_start(next_index: usize) -> Self { - Self { - next_index, - assigned: HashMap::new(), - } - } - - fn ensure_mapping(&mut self, upstream_index: usize) -> usize { - *self.assigned.entry(upstream_index).or_insert_with(|| { - let assigned = self.next_index; - self.next_index += 1; - assigned - }) - } - - fn lookup(&self, upstream_index: usize) -> Option { - self.assigned.get(&upstream_index).copied() - } - - fn allocate_synthetic(&mut self) -> usize { - let assigned = self.next_index; - self.next_index += 1; - assigned - } - - fn next_index(&self) -> usize { - self.next_index - } -} - -/// Action to take based on streaming event processing -#[derive(Debug)] -enum StreamAction { - Forward, // Pass event to client - Buffer, // Accumulate for tool execution - ExecuteTools, // Function call complete, execute now -} - -/// Handles streaming responses with MCP tool call interception -struct StreamingToolHandler { - /// Accumulator for response persistence - accumulator: StreamingResponseAccumulator, - /// Function calls being built from deltas - pending_calls: Vec, - /// Track if we're currently in a function call - in_function_call: bool, - /// Manage output_index remapping so they increment per item - output_index_mapper: OutputIndexMapper, - /// Original response id captured from the first response.created event - original_response_id: Option, -} - -impl StreamingToolHandler { - fn with_starting_index(start: usize) -> Self { - Self { - accumulator: StreamingResponseAccumulator::new(), - pending_calls: Vec::new(), - in_function_call: false, - output_index_mapper: OutputIndexMapper::with_start(start), - original_response_id: None, - } - } - - fn ensure_output_index(&mut self, upstream_index: usize) -> usize { - self.output_index_mapper.ensure_mapping(upstream_index) - } - - fn mapped_output_index(&self, upstream_index: usize) -> Option { - self.output_index_mapper.lookup(upstream_index) - } - - fn allocate_synthetic_output_index(&mut self) -> usize { - self.output_index_mapper.allocate_synthetic() - } - - fn next_output_index(&self) -> usize { - self.output_index_mapper.next_index() - } - - fn original_response_id(&self) -> Option<&str> { - self.original_response_id - .as_deref() - .or_else(|| self.accumulator.original_response_id()) - } - - fn snapshot_final_response(&self) -> Option { - self.accumulator.snapshot_final_response() - } - - /// Process an SSE event and determine what action to take - fn process_event(&mut self, event_name: Option<&str>, data: &str) -> StreamAction { - // Always feed to accumulator for storage - self.accumulator.ingest_block(&format!( - "{}data: {}", - event_name - .map(|n| format!("event: {}\n", n)) - .unwrap_or_default(), - data - )); - - let parsed: Value = match serde_json::from_str(data) { - Ok(v) => v, - Err(_) => return StreamAction::Forward, - }; - - let event_type = event_name - .map(|s| s.to_string()) - .or_else(|| { - parsed - .get("type") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - }) - .unwrap_or_default(); - - match event_type.as_str() { - event_types::RESPONSE_CREATED => { - if self.original_response_id.is_none() { - if let Some(response_obj) = parsed.get("response").and_then(|v| v.as_object()) { - if let Some(id) = response_obj.get("id").and_then(|v| v.as_str()) { - self.original_response_id = Some(id.to_string()); - } - } - } - StreamAction::Forward - } - event_types::RESPONSE_COMPLETED => StreamAction::Forward, - event_types::OUTPUT_ITEM_ADDED => { - if let Some(idx) = parsed.get("output_index").and_then(|v| v.as_u64()) { - self.ensure_output_index(idx as usize); - } - - // Check if this is a function_call item being added - if let Some(item) = parsed.get("item") { - if let Some(item_type) = item.get("type").and_then(|v| v.as_str()) { - if item_type == event_types::ITEM_TYPE_FUNCTION_CALL - || item_type == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL - { - match parsed.get("output_index").and_then(|v| v.as_u64()) { - Some(idx) => { - let output_index = idx as usize; - let assigned_index = self.ensure_output_index(output_index); - let call_id = - item.get("call_id").and_then(|v| v.as_str()).unwrap_or(""); - let name = - item.get("name").and_then(|v| v.as_str()).unwrap_or(""); - - // Create or update the function call - let call = self.get_or_create_call(output_index, item); - call.call_id = call_id.to_string(); - call.name = name.to_string(); - call.assigned_output_index = Some(assigned_index); - - self.in_function_call = true; - } - None => { - tracing::warn!( - "Missing output_index in function_call added event, \ - forwarding without processing for tool execution" - ); - } - } - } - } - } - StreamAction::Forward - } - event_types::FUNCTION_CALL_ARGUMENTS_DELTA => { - // Accumulate arguments for the function call - if let Some(output_index) = parsed - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - { - let assigned_index = self.ensure_output_index(output_index); - if let Some(delta) = parsed.get("delta").and_then(|v| v.as_str()) { - if let Some(call) = self - .pending_calls - .iter_mut() - .find(|c| c.output_index == output_index) - { - call.arguments_buffer.push_str(delta); - if let Some(obfuscation) = - parsed.get("obfuscation").and_then(|v| v.as_str()) - { - call.last_obfuscation = Some(obfuscation.to_string()); - } - if call.assigned_output_index.is_none() { - call.assigned_output_index = Some(assigned_index); - } - } - } - } - StreamAction::Forward - } - event_types::FUNCTION_CALL_ARGUMENTS_DONE => { - // Function call arguments complete - check if ready to execute - if let Some(output_index) = parsed - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - { - let assigned_index = self.ensure_output_index(output_index); - if let Some(call) = self - .pending_calls - .iter_mut() - .find(|c| c.output_index == output_index) - { - if call.assigned_output_index.is_none() { - call.assigned_output_index = Some(assigned_index); - } - } - } - - if self.has_complete_calls() { - StreamAction::ExecuteTools - } else { - StreamAction::Forward - } - } - event_types::OUTPUT_ITEM_DELTA => self.process_output_delta(&parsed), - event_types::OUTPUT_ITEM_DONE => { - // Check if we have complete function calls ready to execute - if let Some(output_index) = parsed - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - { - self.ensure_output_index(output_index); - } - - if self.has_complete_calls() { - StreamAction::ExecuteTools - } else { - StreamAction::Forward - } - } - _ => StreamAction::Forward, - } - } - - /// Process output delta events to detect and accumulate function calls - fn process_output_delta(&mut self, event: &Value) -> StreamAction { - let output_index = event - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - .unwrap_or(0); - - let assigned_index = self.ensure_output_index(output_index); - - let delta = match event.get("delta") { - Some(d) => d, - None => return StreamAction::Forward, - }; - - // Check if this is a function call delta - let item_type = delta.get("type").and_then(|v| v.as_str()); - - if item_type == Some(event_types::ITEM_TYPE_FUNCTION_TOOL_CALL) - || item_type == Some(event_types::ITEM_TYPE_FUNCTION_CALL) - { - self.in_function_call = true; - - // Get or create function call for this output index - let call = self.get_or_create_call(output_index, delta); - call.assigned_output_index = Some(assigned_index); - - // Accumulate call_id if present - if let Some(call_id) = delta.get("call_id").and_then(|v| v.as_str()) { - call.call_id = call_id.to_string(); - } - - // Accumulate name if present - if let Some(name) = delta.get("name").and_then(|v| v.as_str()) { - call.name.push_str(name); - } - - // Accumulate arguments if present - if let Some(args) = delta.get("arguments").and_then(|v| v.as_str()) { - call.arguments_buffer.push_str(args); - } - - if let Some(obfuscation) = delta.get("obfuscation").and_then(|v| v.as_str()) { - call.last_obfuscation = Some(obfuscation.to_string()); - } - - // Buffer this event, don't forward to client - return StreamAction::Buffer; - } - - // Forward non-function-call events - StreamAction::Forward - } - - fn get_or_create_call( - &mut self, - output_index: usize, - delta: &Value, - ) -> &mut FunctionCallInProgress { - // Find existing call for this output index - // Note: We use position() + index instead of iter_mut().find() because we need - // to potentially mutate pending_calls after the early return, which causes - // borrow checker issues with the iter_mut approach - if let Some(pos) = self - .pending_calls - .iter() - .position(|c| c.output_index == output_index) - { - return &mut self.pending_calls[pos]; - } - - // Create new call - let call_id = delta - .get("call_id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - let mut call = FunctionCallInProgress::new(call_id, output_index); - if let Some(obfuscation) = delta.get("obfuscation").and_then(|v| v.as_str()) { - call.last_obfuscation = Some(obfuscation.to_string()); - } - - self.pending_calls.push(call); - self.pending_calls - .last_mut() - .expect("Just pushed to pending_calls, must have at least one element") - } - - fn has_complete_calls(&self) -> bool { - !self.pending_calls.is_empty() && self.pending_calls.iter().all(|c| c.is_complete()) - } - - fn take_pending_calls(&mut self) -> Vec { - std::mem::take(&mut self.pending_calls) - } -} - -impl StreamingResponseAccumulator { - fn new() -> Self { - Self { - initial_response: None, - completed_response: None, - output_items: Vec::new(), - encountered_error: None, - } - } - - /// Feed the accumulator with the next SSE chunk. - fn ingest_block(&mut self, block: &str) { - if block.trim().is_empty() { - return; - } - self.process_block(block); - } - - /// Consume the accumulator and produce the best-effort final response value. - fn into_final_response(mut self) -> Option { - if self.completed_response.is_some() { - return self.completed_response; - } - - self.build_fallback_response() - } - - fn encountered_error(&self) -> Option<&Value> { - self.encountered_error.as_ref() - } - - fn original_response_id(&self) -> Option<&str> { - self.initial_response - .as_ref() - .and_then(|response| response.get("id")) - .and_then(|id| id.as_str()) - } - - fn snapshot_final_response(&self) -> Option { - if let Some(resp) = &self.completed_response { - return Some(resp.clone()); - } - self.build_fallback_response_snapshot() - } - - fn build_fallback_response_snapshot(&self) -> Option { - let mut response = self.initial_response.clone()?; - - if let Some(obj) = response.as_object_mut() { - obj.insert("status".to_string(), Value::String("completed".to_string())); - - let mut output_items = self.output_items.clone(); - output_items.sort_by_key(|(index, _)| *index); - let outputs: Vec = output_items.into_iter().map(|(_, item)| item).collect(); - obj.insert("output".to_string(), Value::Array(outputs)); - } - - Some(response) - } - - fn process_block(&mut self, block: &str) { - let trimmed = block.trim(); - if trimmed.is_empty() { - return; - } - - let mut event_name: Option = None; - let mut data_lines: Vec = Vec::new(); - - for line in trimmed.lines() { - if let Some(rest) = line.strip_prefix("event:") { - event_name = Some(rest.trim().to_string()); - } else if let Some(rest) = line.strip_prefix("data:") { - data_lines.push(rest.trim_start().to_string()); - } - } - - let data_payload = data_lines.join("\n"); - if data_payload.is_empty() { - return; - } - - self.handle_event(event_name.as_deref(), &data_payload); - } - - fn handle_event(&mut self, event_name: Option<&str>, data_payload: &str) { - let parsed: Value = match serde_json::from_str(data_payload) { - Ok(value) => value, - Err(err) => { - warn!("Failed to parse streaming event JSON: {}", err); - return; - } - }; - - let event_type = event_name - .map(|s| s.to_string()) - .or_else(|| { - parsed - .get("type") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - }) - .unwrap_or_default(); - - match event_type.as_str() { - event_types::RESPONSE_CREATED => { - if self.initial_response.is_none() { - if let Some(response) = parsed.get("response") { - self.initial_response = Some(response.clone()); - } - } - } - event_types::RESPONSE_COMPLETED => { - if let Some(response) = parsed.get("response") { - self.completed_response = Some(response.clone()); - } - } - event_types::OUTPUT_ITEM_DONE => { - if let (Some(index), Some(item)) = ( - parsed - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize), - parsed.get("item"), - ) { - self.output_items.push((index, item.clone())); - } - } - "response.error" => { - self.encountered_error = Some(parsed); - } - _ => {} - } - } - - fn build_fallback_response(&mut self) -> Option { - let mut response = self.initial_response.clone()?; - - if let Some(obj) = response.as_object_mut() { - obj.insert("status".to_string(), Value::String("completed".to_string())); - - self.output_items.sort_by_key(|(index, _)| *index); - let outputs: Vec = self - .output_items - .iter() - .map(|(_, item)| item.clone()) - .collect(); - obj.insert("output".to_string(), Value::Array(outputs)); - } - - Some(response) - } -} - -impl OpenAIRouter { - // Maximum number of conversation items to attach as input when a conversation is provided - const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100; - /// Create a new OpenAI router - pub async fn new( - base_url: String, - circuit_breaker_config: Option, - response_storage: SharedResponseStorage, - conversation_storage: SharedConversationStorage, - conversation_item_storage: SharedConversationItemStorage, - ) -> Result { - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(300)) - .build() - .map_err(|e| format!("Failed to create HTTP client: {}", e))?; - - let base_url = base_url.trim_end_matches('/').to_string(); - - // Convert circuit breaker config - let core_cb_config = circuit_breaker_config - .map(|cb| CoreCircuitBreakerConfig { - failure_threshold: cb.failure_threshold, - success_threshold: cb.success_threshold, - timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs), - window_duration: std::time::Duration::from_secs(cb.window_duration_secs), - }) - .unwrap_or_default(); - - let circuit_breaker = CircuitBreaker::with_config(core_cb_config); - - // Optional MCP manager activation via env var path (config-driven gate) - let mcp_manager = match std::env::var("SGLANG_MCP_CONFIG").ok() { - Some(path) if !path.trim().is_empty() => { - match crate::mcp::McpConfig::from_file(&path).await { - Ok(cfg) => match crate::mcp::McpClientManager::new(cfg).await { - Ok(mgr) => Some(Arc::new(mgr)), - Err(err) => { - warn!("Failed to initialize MCP manager: {}", err); - None - } - }, - Err(err) => { - warn!("Failed to load MCP config from '{}': {}", path, err); - None - } - } - } - _ => None, - }; - - Ok(Self { - client, - base_url, - circuit_breaker, - healthy: AtomicBool::new(true), - response_storage, - conversation_storage, - conversation_item_storage, - mcp_manager, - }) - } - - async fn handle_non_streaming_response( - &self, - url: String, - headers: Option<&HeaderMap>, - mut payload: Value, - original_body: &ResponsesRequest, - original_previous_response_id: Option, - ) -> Response { - // Request-scoped MCP: build from request tools if provided; otherwise fall back to router-level MCP - let req_mcp_manager = Self::mcp_manager_from_request_tools(&original_body.tools).await; - let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref()); - - // If the client requested MCP but we couldn't initialize it, fail early with a clear error - let requested_mcp = original_body - .tools - .iter() - .any(|t| matches!(t.r#type, ResponseToolType::Mcp)); - if requested_mcp && active_mcp.is_none() { - return ( - StatusCode::BAD_GATEWAY, - json!({ - "error": { - "message": "MCP server unavailable or failed to initialize from request tools", - "type": "mcp_unavailable", - "param": "tools", - } - }) - .to_string(), - ) - .into_response(); - } - - // If MCP is active, mirror one function tool into the outgoing payload - if let Some(mcp) = active_mcp { - if let Some(obj) = payload.as_object_mut() { - // Remove any non-function tools (e.g., custom "mcp" items) from outgoing payload - if let Some(v) = obj.get_mut("tools") { - if let Some(arr) = v.as_array_mut() { - arr.retain(|item| { - item.get("type") - .and_then(|v| v.as_str()) - .map(|s| s == "function") - .unwrap_or(false) - }); - if arr.is_empty() { - obj.remove("tools"); - obj.insert( - "tool_choice".to_string(), - Value::String("none".to_string()), - ); - } - } - } - // Build function tools for all discovered MCP tools - let mut tools_json = Vec::new(); - let tools = mcp.list_tools(); - for t in tools { - let parameters = t.parameters.clone().unwrap_or(serde_json::json!({ - "type": "object", - "properties": {}, - "additionalProperties": false - })); - let tool = serde_json::json!({ - "type": "function", - "name": t.name, - "description": t.description, - "parameters": parameters - }); - tools_json.push(tool); - } - if !tools_json.is_empty() { - obj.insert("tools".to_string(), Value::Array(tools_json)); - // Ensure tool_choice auto to allow model planning - obj.insert("tool_choice".to_string(), Value::String("auto".to_string())); - } - } - } - let request_builder = self.client.post(&url).json(&payload); - - // Apply headers with filtering - let request_builder = if let Some(headers) = headers { - apply_request_headers(headers, request_builder, true) - } else { - request_builder - }; - - match request_builder.send().await { - Ok(response) => { - let status = response.status(); - if !status.is_success() { - let error_text = response - .text() - .await - .unwrap_or_else(|e| format!("Failed to get error body: {}", e)); - return (status, error_text).into_response(); - } - - // Parse the response - match response.json::().await { - Ok(mut openai_response_json) => { - if let Some(prev_id) = original_previous_response_id { - if let Some(obj) = openai_response_json.as_object_mut() { - let should_insert = obj - .get("previous_response_id") - .map(|v| v.is_null()) - .unwrap_or(true); - if should_insert { - obj.insert( - "previous_response_id".to_string(), - Value::String(prev_id), - ); - } - } - } - - if let Some(obj) = openai_response_json.as_object_mut() { - if !obj.contains_key("instructions") { - if let Some(instructions) = &original_body.instructions { - obj.insert( - "instructions".to_string(), - Value::String(instructions.clone()), - ); - } - } - - if !obj.contains_key("metadata") { - if let Some(metadata) = &original_body.metadata { - let metadata_map: serde_json::Map = metadata - .iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect(); - obj.insert("metadata".to_string(), Value::Object(metadata_map)); - } - } - - // Reflect the client's requested store preference in the response body - obj.insert("store".to_string(), Value::Bool(original_body.store)); - } - - // If MCP is active and we detect a function call, enter the tool loop - let mut final_response_json = if let Some(mcp) = active_mcp { - if Self::extract_function_call(&openai_response_json).is_some() { - // Use the loop to handle potentially multiple tool calls - let loop_config = McpLoopConfig::default(); - match self - .execute_tool_loop( - &url, - headers, - payload.clone(), - original_body, - mcp, - &loop_config, - ) - .await - { - Ok(loop_result) => loop_result, - Err(err) => { - warn!("Tool loop failed: {}", err); - let error_body = json!({ - "error": { - "message": format!("Tool loop failed: {}", err), - "type": "internal_error", - } - }) - .to_string(); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - [("content-type", "application/json")], - error_body, - ) - .into_response(); - } - } - } else { - // No function call detected, use response as-is - openai_response_json - } - } else { - openai_response_json - }; - - // Mask tools back to MCP format for client - Self::mask_tools_as_mcp(&mut final_response_json, original_body); - // Attach conversation id for client response if present (not forwarded upstream) - if let Some(conv_id) = original_body.conversation.clone() { - if let Some(obj) = final_response_json.as_object_mut() { - obj.insert("conversation".to_string(), json!({"id": conv_id})); - } - } - if original_body.store { - if let Err(e) = self - .store_response_internal(&final_response_json, original_body) - .await - { - warn!("Failed to store response: {}", e); - } - } - if let Some(conv_id) = original_body.conversation.clone() { - if let Err(err) = self - .persist_conversation_items( - &conv_id, - original_body, - &final_response_json, - ) - .await - { - warn!("Failed to persist conversation items: {}", err); - } - } - - match serde_json::to_string(&final_response_json) { - Ok(json_str) => ( - StatusCode::OK, - [("content-type", "application/json")], - json_str, - ) - .into_response(), - Err(e) => { - error!("Failed to serialize response: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - json!({"error": {"message": "Failed to serialize response", "type": "internal_error"}}).to_string(), - ) - .into_response() - } - } - } - Err(e) => { - error!("Failed to parse OpenAI response: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to parse response: {}", e), - ) - .into_response() - } - } - } - Err(e) => ( - StatusCode::BAD_GATEWAY, - format!("Failed to forward request to OpenAI: {}", e), - ) - .into_response(), - } - } - - async fn persist_conversation_items( - &self, - conversation_id: &str, - original_body: &ResponsesRequest, - final_response_json: &Value, - ) -> Result<(), String> { - persist_items_with_storages( - self.conversation_storage.clone(), - self.conversation_item_storage.clone(), - conversation_id.to_string(), - original_body.clone(), - final_response_json.clone(), - ) - .await - } - - /// Build a request-scoped MCP manager from request tools, if present. - async fn mcp_manager_from_request_tools( - tools: &[ResponseTool], - ) -> Option> { - let tool = tools - .iter() - .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?; - let server_url = tool.server_url.as_ref()?.trim().to_string(); - if !(server_url.starts_with("http://") || server_url.starts_with("https://")) { - warn!( - "Ignoring MCP server_url with unsupported scheme: {}", - server_url - ); - return None; - } - let name = tool - .server_label - .clone() - .unwrap_or_else(|| "request-mcp".to_string()); - let token = tool.authorization.clone(); - let transport = if server_url.contains("/sse") { - crate::mcp::McpTransport::Sse { - url: server_url, - token, - } - } else { - crate::mcp::McpTransport::Streamable { - url: server_url, - token, - } - }; - let cfg = crate::mcp::McpConfig { - servers: vec![crate::mcp::McpServerConfig { name, transport }], - }; - match crate::mcp::McpClientManager::new(cfg).await { - Ok(mgr) => Some(Arc::new(mgr)), - Err(err) => { - warn!("Failed to initialize request-scoped MCP manager: {}", err); - None - } - } - } - - async fn handle_streaming_response( - &self, - url: String, - headers: Option<&HeaderMap>, - payload: Value, - original_body: &ResponsesRequest, - original_previous_response_id: Option, - ) -> Response { - // Check if MCP is active for this request - let req_mcp_manager = Self::mcp_manager_from_request_tools(&original_body.tools).await; - let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref()); - - // If no MCP is active, use simple pass-through streaming - if active_mcp.is_none() { - return self - .handle_simple_streaming_passthrough( - url, - headers, - payload, - original_body, - original_previous_response_id, - ) - .await; - } - - let active_mcp = active_mcp.unwrap(); - - // MCP is active - transform tools and set up interception - self.handle_streaming_with_tool_interception( - url, - headers, - payload, - original_body, - original_previous_response_id, - active_mcp, - ) - .await - } - - /// Simple pass-through streaming without MCP interception - async fn handle_simple_streaming_passthrough( - &self, - url: String, - headers: Option<&HeaderMap>, - payload: Value, - original_body: &ResponsesRequest, - original_previous_response_id: Option, - ) -> Response { - let mut request_builder = self.client.post(&url).json(&payload); - - if let Some(headers) = headers { - request_builder = apply_request_headers(headers, request_builder, true); - } - - request_builder = request_builder.header("Accept", "text/event-stream"); - - let response = match request_builder.send().await { - Ok(resp) => resp, - Err(err) => { - self.circuit_breaker.record_failure(); - return ( - StatusCode::BAD_GATEWAY, - format!("Failed to forward request to OpenAI: {}", err), - ) - .into_response(); - } - }; - - let status = response.status(); - let status_code = - StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - - if !status.is_success() { - self.circuit_breaker.record_failure(); - let error_body = match response.text().await { - Ok(body) => body, - Err(err) => format!("Failed to read upstream error body: {}", err), - }; - return (status_code, error_body).into_response(); - } - - self.circuit_breaker.record_success(); - - let preserved_headers = preserve_response_headers(response.headers()); - let mut upstream_stream = response.bytes_stream(); - - let (tx, rx) = mpsc::unbounded_channel::>(); - - let should_store = original_body.store; - let storage = self.response_storage.clone(); - let conv_storage = self.conversation_storage.clone(); - let conv_item_storage = self.conversation_item_storage.clone(); - let original_request = original_body.clone(); - let persist_needed = original_request.conversation.is_some(); - let previous_response_id = original_previous_response_id.clone(); - - tokio::spawn(async move { - let mut accumulator = StreamingResponseAccumulator::new(); - let mut upstream_failed = false; - let mut receiver_connected = true; - let mut pending = String::new(); - - while let Some(chunk_result) = upstream_stream.next().await { - match chunk_result { - Ok(chunk) => { - let chunk_text = match std::str::from_utf8(&chunk) { - Ok(text) => Cow::Borrowed(text), - Err(_) => Cow::Owned(String::from_utf8_lossy(&chunk).to_string()), - }; - - pending.push_str(&chunk_text.replace("\r\n", "\n")); - - while let Some(pos) = pending.find("\n\n") { - let raw_block = pending[..pos].to_string(); - pending.drain(..pos + 2); - - if raw_block.trim().is_empty() { - continue; - } - - let block_cow = if let Some(modified) = Self::rewrite_streaming_block( - raw_block.as_str(), - &original_request, - previous_response_id.as_deref(), - ) { - Cow::Owned(modified) - } else { - Cow::Borrowed(raw_block.as_str()) - }; - - if should_store || persist_needed { - accumulator.ingest_block(block_cow.as_ref()); - } - - if receiver_connected { - let chunk_to_send = format!("{}\n\n", block_cow); - if tx.send(Ok(Bytes::from(chunk_to_send))).is_err() { - receiver_connected = false; - } - } - - if !receiver_connected && !should_store { - break; - } - } - - if !receiver_connected && !should_store { - break; - } - } - Err(err) => { - upstream_failed = true; - let io_err = io::Error::other(err); - let _ = tx.send(Err(io_err)); - break; - } - } - } - - if (should_store || persist_needed) && !upstream_failed { - if !pending.trim().is_empty() { - accumulator.ingest_block(&pending); - } - let encountered_error = accumulator.encountered_error().cloned(); - if let Some(mut response_json) = accumulator.into_final_response() { - Self::patch_streaming_response_json( - &mut response_json, - &original_request, - previous_response_id.as_deref(), - ); - - if should_store { - if let Err(err) = - Self::store_response_impl(&storage, &response_json, &original_request) - .await - { - warn!("Failed to store streaming response: {}", err); - } - } - if persist_needed { - if let Some(conv_id) = original_request.conversation.clone() { - if let Err(err) = persist_items_with_storages( - conv_storage.clone(), - conv_item_storage.clone(), - conv_id, - original_request.clone(), - response_json.clone(), - ) - .await - { - warn!("Failed to persist conversation items (stream): {}", err); - } - } - } - } else if let Some(error_payload) = encountered_error { - warn!("Upstream streaming error payload: {}", error_payload); - } else { - warn!("Streaming completed without a final response payload"); - } - } - }); - - let body_stream = UnboundedReceiverStream::new(rx); - let mut response = Response::new(Body::from_stream(body_stream)); - *response.status_mut() = status_code; - - let headers_mut = response.headers_mut(); - for (name, value) in preserved_headers.iter() { - headers_mut.insert(name, value.clone()); - } - - if !headers_mut.contains_key(CONTENT_TYPE) { - headers_mut.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); - } - - response - } - - /// Apply all transformations to event data in-place (rewrite + transform) - /// Optimized to parse JSON only once instead of multiple times - /// Returns true if any changes were made - fn apply_event_transformations_inplace( - parsed_data: &mut Value, - server_label: &str, - original_request: &ResponsesRequest, - previous_response_id: Option<&str>, - ) -> bool { - let mut changed = false; - - // 1. Apply rewrite_streaming_block logic (store, previous_response_id, tools masking) - let event_type = parsed_data - .get("type") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - .unwrap_or_default(); - - let should_patch = matches!( - event_type.as_str(), - event_types::RESPONSE_CREATED - | event_types::RESPONSE_IN_PROGRESS - | event_types::RESPONSE_COMPLETED - ); - - if should_patch { - if let Some(response_obj) = parsed_data - .get_mut("response") - .and_then(|v| v.as_object_mut()) - { - let desired_store = Value::Bool(original_request.store); - if response_obj.get("store") != Some(&desired_store) { - response_obj.insert("store".to_string(), desired_store); - changed = true; - } - - if let Some(prev_id) = previous_response_id { - let needs_previous = response_obj - .get("previous_response_id") - .map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false)) - .unwrap_or(true); - - if needs_previous { - response_obj.insert( - "previous_response_id".to_string(), - Value::String(prev_id.to_string()), - ); - changed = true; - } - } - - // Mask tools from function to MCP format (optimized without cloning) - if response_obj.get("tools").is_some() { - let requested_mcp = original_request - .tools - .iter() - .any(|t| matches!(t.r#type, ResponseToolType::Mcp)); - - if requested_mcp { - if let Some(mcp_tools) = Self::build_mcp_tools_value(original_request) { - response_obj.insert("tools".to_string(), mcp_tools); - response_obj - .entry("tool_choice".to_string()) - .or_insert(Value::String("auto".to_string())); - changed = true; - } - } - } - } - } - - // 2. Apply transform_streaming_event logic (function_call → mcp_call) - match event_type.as_str() { - event_types::OUTPUT_ITEM_ADDED | event_types::OUTPUT_ITEM_DONE => { - if let Some(item) = parsed_data.get_mut("item") { - if let Some(item_type) = item.get("type").and_then(|v| v.as_str()) { - if item_type == event_types::ITEM_TYPE_FUNCTION_CALL - || item_type == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL - { - item["type"] = json!(event_types::ITEM_TYPE_MCP_CALL); - item["server_label"] = json!(server_label); - - // Transform ID from fc_* to mcp_* - if let Some(id) = item.get("id").and_then(|v| v.as_str()) { - if let Some(stripped) = id.strip_prefix("fc_") { - let new_id = format!("mcp_{}", stripped); - item["id"] = json!(new_id); - } - } - - changed = true; - } - } - } - } - event_types::FUNCTION_CALL_ARGUMENTS_DONE => { - parsed_data["type"] = json!(event_types::MCP_CALL_ARGUMENTS_DONE); - - // Transform item_id from fc_* to mcp_* - if let Some(item_id) = parsed_data.get("item_id").and_then(|v| v.as_str()) { - if let Some(stripped) = item_id.strip_prefix("fc_") { - let new_id = format!("mcp_{}", stripped); - parsed_data["item_id"] = json!(new_id); - } - } - - changed = true; - } - _ => {} - } - - changed - } - - /// Forward and transform a streaming event to the client - /// Returns false if client disconnected - #[allow(clippy::too_many_arguments)] - fn forward_streaming_event( - raw_block: &str, - event_name: Option<&str>, - data: &str, - handler: &mut StreamingToolHandler, - tx: &mpsc::UnboundedSender>, - server_label: &str, - original_request: &ResponsesRequest, - previous_response_id: Option<&str>, - sequence_number: &mut u64, - ) -> bool { - // Skip individual function_call_arguments.delta events - we'll send them as one - if event_name == Some(event_types::FUNCTION_CALL_ARGUMENTS_DELTA) { - return true; - } - - // Parse JSON data once (optimized!) - let mut parsed_data: Value = match serde_json::from_str(data) { - Ok(v) => v, - Err(_) => { - // If parsing fails, forward raw block as-is - let chunk_to_send = format!("{}\n\n", raw_block); - return tx.send(Ok(Bytes::from(chunk_to_send))).is_ok(); - } - }; - - let event_type = event_name - .or_else(|| parsed_data.get("type").and_then(|v| v.as_str())) - .unwrap_or(""); - - if event_type == event_types::RESPONSE_COMPLETED { - return true; - } - - // Check if this is function_call_arguments.done - need to send buffered args first - let mut mapped_output_index: Option = None; - - if event_name == Some(event_types::FUNCTION_CALL_ARGUMENTS_DONE) { - if let Some(output_index) = parsed_data - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - { - let assigned_index = handler - .mapped_output_index(output_index) - .unwrap_or(output_index); - mapped_output_index = Some(assigned_index); - - if let Some(call) = handler - .pending_calls - .iter() - .find(|c| c.output_index == output_index) - { - let arguments_value = if call.arguments_buffer.is_empty() { - "{}".to_string() - } else { - call.arguments_buffer.clone() - }; - - // Make sure the done event carries full arguments - parsed_data["arguments"] = Value::String(arguments_value.clone()); - - // Get item_id and transform it - let item_id = parsed_data - .get("item_id") - .and_then(|v| v.as_str()) - .unwrap_or(""); - let mcp_item_id = if let Some(stripped) = item_id.strip_prefix("fc_") { - format!("mcp_{}", stripped) - } else { - item_id.to_string() - }; - - // Emit a synthetic MCP arguments delta event before the done event - let mut delta_event = json!({ - "type": event_types::MCP_CALL_ARGUMENTS_DELTA, - "sequence_number": *sequence_number, - "output_index": assigned_index, - "item_id": mcp_item_id, - "delta": arguments_value, - }); - - if let Some(obfuscation) = call.last_obfuscation.as_ref() { - if let Some(obj) = delta_event.as_object_mut() { - obj.insert( - "obfuscation".to_string(), - Value::String(obfuscation.clone()), - ); - } - } else if let Some(obfuscation) = parsed_data.get("obfuscation").cloned() { - if let Some(obj) = delta_event.as_object_mut() { - obj.insert("obfuscation".to_string(), obfuscation); - } - } - - let delta_block = format!( - "event: {}\ndata: {}\n\n", - event_types::MCP_CALL_ARGUMENTS_DELTA, - delta_event - ); - if tx.send(Ok(Bytes::from(delta_block))).is_err() { - return false; - } - - *sequence_number += 1; - } - } - } - - // Remap output_index (if present) so downstream sees sequential indices - if mapped_output_index.is_none() { - if let Some(output_index) = parsed_data - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - { - mapped_output_index = handler.mapped_output_index(output_index); - } - } - - if let Some(mapped) = mapped_output_index { - parsed_data["output_index"] = json!(mapped); - } - - // Apply all transformations in-place (single parse/serialize!) - Self::apply_event_transformations_inplace( - &mut parsed_data, - server_label, - original_request, - previous_response_id, - ); - - if let Some(response_obj) = parsed_data - .get_mut("response") - .and_then(|v| v.as_object_mut()) - { - if let Some(original_id) = handler.original_response_id() { - response_obj.insert("id".to_string(), Value::String(original_id.to_string())); - } - } - - // Update sequence number if present in the event - if parsed_data.get("sequence_number").is_some() { - parsed_data["sequence_number"] = json!(*sequence_number); - *sequence_number += 1; - } - - // Serialize once - let final_data = match serde_json::to_string(&parsed_data) { - Ok(s) => s, - Err(_) => { - // Serialization failed, forward original - let chunk_to_send = format!("{}\n\n", raw_block); - return tx.send(Ok(Bytes::from(chunk_to_send))).is_ok(); - } - }; - - // Rebuild SSE block with potentially transformed event name - let mut final_block = String::new(); - if let Some(evt) = event_name { - // Update event name for function_call_arguments events - if evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA { - final_block.push_str(&format!( - "event: {}\n", - event_types::MCP_CALL_ARGUMENTS_DELTA - )); - } else if evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE { - final_block.push_str(&format!( - "event: {}\n", - event_types::MCP_CALL_ARGUMENTS_DONE - )); - } else { - final_block.push_str(&format!("event: {}\n", evt)); - } - } - final_block.push_str(&format!("data: {}", final_data)); - - let chunk_to_send = format!("{}\n\n", final_block); - if tx.send(Ok(Bytes::from(chunk_to_send))).is_err() { - return false; - } - - // After sending output_item.added for mcp_call, inject mcp_call.in_progress event - if event_name == Some(event_types::OUTPUT_ITEM_ADDED) { - if let Some(item) = parsed_data.get("item") { - if item.get("type").and_then(|v| v.as_str()) - == Some(event_types::ITEM_TYPE_MCP_CALL) - { - // Already transformed to mcp_call - if let (Some(item_id), Some(output_index)) = ( - item.get("id").and_then(|v| v.as_str()), - parsed_data.get("output_index").and_then(|v| v.as_u64()), - ) { - let in_progress_event = json!({ - "type": event_types::MCP_CALL_IN_PROGRESS, - "sequence_number": *sequence_number, - "output_index": output_index, - "item_id": item_id - }); - *sequence_number += 1; - let in_progress_block = format!( - "event: {}\ndata: {}\n\n", - event_types::MCP_CALL_IN_PROGRESS, - in_progress_event - ); - if tx.send(Ok(Bytes::from(in_progress_block))).is_err() { - return false; - } - } - } - } - } - - true - } - - /// Execute detected tool calls and send completion events to client - /// Returns false if client disconnected during execution - async fn execute_streaming_tool_calls( - pending_calls: Vec, - active_mcp: &Arc, - tx: &mpsc::UnboundedSender>, - state: &mut ToolLoopState, - server_label: &str, - sequence_number: &mut u64, - ) -> bool { - // Execute all pending tool calls (sequential, as PR3 is skipped) - for call in pending_calls { - // Skip if name is empty (invalid call) - if call.name.is_empty() { - warn!( - "Skipping incomplete tool call: name is empty, args_len={}", - call.arguments_buffer.len() - ); - continue; - } - - info!( - "Executing tool call during streaming: {} ({})", - call.name, call.call_id - ); - - // Use empty JSON object if arguments_buffer is empty - let args_str = if call.arguments_buffer.is_empty() { - "{}" - } else { - &call.arguments_buffer - }; - - let call_result = Self::execute_mcp_call(active_mcp, &call.name, args_str).await; - let (output_str, success, error_msg) = match call_result { - Ok((_, output)) => (output, true, None), - Err(err) => { - warn!("Tool execution failed during streaming: {}", err); - (json!({ "error": &err }).to_string(), false, Some(err)) - } - }; - - // Send mcp_call completion event to client - if !OpenAIRouter::send_mcp_call_completion_events_with_error( - tx, - &call, - &output_str, - server_label, - success, - error_msg.as_deref(), - sequence_number, - ) { - // Client disconnected, no point continuing tool execution - return false; - } - - // Record the call - state.record_call(call.call_id, call.name, call.arguments_buffer, output_str); - } - true - } - - /// Transform payload to replace MCP tools with function tools for streaming - fn prepare_mcp_payload_for_streaming( - payload: &mut Value, - active_mcp: &Arc, - ) { - if let Some(obj) = payload.as_object_mut() { - // Remove any non-function tools from outgoing payload - if let Some(v) = obj.get_mut("tools") { - if let Some(arr) = v.as_array_mut() { - arr.retain(|item| { - item.get("type") - .and_then(|v| v.as_str()) - .map(|s| s == event_types::ITEM_TYPE_FUNCTION) - .unwrap_or(false) - }); - } - } - - // Build function tools for all discovered MCP tools - let mut tools_json = Vec::new(); - let tools = active_mcp.list_tools(); - for t in tools { - let parameters = t.parameters.clone().unwrap_or(serde_json::json!({ - "type": "object", - "properties": {}, - "additionalProperties": false - })); - let tool = serde_json::json!({ - "type": event_types::ITEM_TYPE_FUNCTION, - "name": t.name, - "description": t.description, - "parameters": parameters - }); - tools_json.push(tool); - } - if !tools_json.is_empty() { - obj.insert("tools".to_string(), Value::Array(tools_json)); - obj.insert("tool_choice".to_string(), Value::String("auto".to_string())); - } - } - } - - /// Handle streaming WITH MCP tool call interception and execution - async fn handle_streaming_with_tool_interception( - &self, - url: String, - headers: Option<&HeaderMap>, - mut payload: Value, - original_body: &ResponsesRequest, - original_previous_response_id: Option, - active_mcp: &Arc, - ) -> Response { - // Transform MCP tools to function tools in payload - Self::prepare_mcp_payload_for_streaming(&mut payload, active_mcp); - - let (tx, rx) = mpsc::unbounded_channel::>(); - let should_store = original_body.store; - let storage = self.response_storage.clone(); - let conv_storage = self.conversation_storage.clone(); - let conv_item_storage = self.conversation_item_storage.clone(); - let original_request = original_body.clone(); - let persist_needed = original_request.conversation.is_some(); - let previous_response_id = original_previous_response_id.clone(); - - let client = self.client.clone(); - let url_clone = url.clone(); - let headers_opt = headers.cloned(); - let payload_clone = payload.clone(); - let active_mcp_clone = Arc::clone(active_mcp); - - // Spawn the streaming loop task - tokio::spawn(async move { - let mut state = ToolLoopState::new(original_request.input.clone()); - let loop_config = McpLoopConfig::default(); - let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize); - let tools_json = payload_clone.get("tools").cloned().unwrap_or(json!([])); - let base_payload = payload_clone.clone(); - let mut current_payload = payload_clone; - let mut mcp_list_tools_sent = false; - let mut is_first_iteration = true; - let mut sequence_number: u64 = 0; // Track global sequence number across all iterations - let mut next_output_index: usize = 0; - let mut preserved_response_id: Option = None; - - let server_label = original_request - .tools - .iter() - .find(|t| matches!(t.r#type, ResponseToolType::Mcp)) - .and_then(|t| t.server_label.as_deref()) - .unwrap_or("mcp"); - - loop { - // Make streaming request - let mut request_builder = client.post(&url_clone).json(¤t_payload); - if let Some(ref h) = headers_opt { - request_builder = apply_request_headers(h, request_builder, true); - } - request_builder = request_builder.header("Accept", "text/event-stream"); - - let response = match request_builder.send().await { - Ok(r) => r, - Err(e) => { - let error_event = format!( - "event: error\ndata: {{\"error\": {{\"message\": \"{}\"}}}}\n\n", - e - ); - let _ = tx.send(Ok(Bytes::from(error_event))); - return; - } - }; - - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - let error_event = format!("event: error\ndata: {{\"error\": {{\"message\": \"Upstream error {}: {}\"}}}}\n\n", status, body); - let _ = tx.send(Ok(Bytes::from(error_event))); - return; - } - - // Stream events and check for tool calls - let mut upstream_stream = response.bytes_stream(); - let mut handler = StreamingToolHandler::with_starting_index(next_output_index); - if let Some(ref id) = preserved_response_id { - handler.original_response_id = Some(id.clone()); - } - let mut pending = String::new(); - let mut tool_calls_detected = false; - let mut seen_in_progress = false; - - while let Some(chunk_result) = upstream_stream.next().await { - match chunk_result { - Ok(chunk) => { - let chunk_text = match std::str::from_utf8(&chunk) { - Ok(text) => Cow::Borrowed(text), - Err(_) => Cow::Owned(String::from_utf8_lossy(&chunk).to_string()), - }; - - pending.push_str(&chunk_text.replace("\r\n", "\n")); - - while let Some(pos) = pending.find("\n\n") { - let raw_block = pending[..pos].to_string(); - pending.drain(..pos + 2); - - if raw_block.trim().is_empty() { - continue; - } - - // Parse event - let (event_name, data) = Self::parse_sse_block(&raw_block); - - if data.is_empty() { - continue; - } - - // Process through handler - let action = handler.process_event(event_name, data.as_ref()); - - match action { - StreamAction::Forward => { - // Skip response.created and response.in_progress on subsequent iterations - // Do NOT consume their sequence numbers - we want continuous numbering - let should_skip = if !is_first_iteration { - if let Ok(parsed) = - serde_json::from_str::(data.as_ref()) - { - matches!( - parsed.get("type").and_then(|v| v.as_str()), - Some(event_types::RESPONSE_CREATED) - | Some(event_types::RESPONSE_IN_PROGRESS) - ) - } else { - false - } - } else { - false - }; - - if !should_skip { - // Forward the event - if !Self::forward_streaming_event( - &raw_block, - event_name, - data.as_ref(), - &mut handler, - &tx, - server_label, - &original_request, - previous_response_id.as_deref(), - &mut sequence_number, - ) { - // Client disconnected - return; - } - } - - // After forwarding response.in_progress, send mcp_list_tools events (once) - if !seen_in_progress { - if let Ok(parsed) = - serde_json::from_str::(data.as_ref()) - { - if parsed.get("type").and_then(|v| v.as_str()) - == Some(event_types::RESPONSE_IN_PROGRESS) - { - seen_in_progress = true; - if !mcp_list_tools_sent { - let list_tools_index = handler - .allocate_synthetic_output_index(); - if !OpenAIRouter::send_mcp_list_tools_events( - &tx, - &active_mcp_clone, - server_label, - list_tools_index, - &mut sequence_number, - ) { - // Client disconnected - return; - } - mcp_list_tools_sent = true; - } - } - } - } - } - StreamAction::Buffer => { - // Don't forward, just buffer - } - StreamAction::ExecuteTools => { - if !Self::forward_streaming_event( - &raw_block, - event_name, - data.as_ref(), - &mut handler, - &tx, - server_label, - &original_request, - previous_response_id.as_deref(), - &mut sequence_number, - ) { - // Client disconnected - return; - } - tool_calls_detected = true; - break; // Exit stream processing to execute tools - } - } - } - - if tool_calls_detected { - break; - } - } - Err(e) => { - let error_event = format!("event: error\ndata: {{\"error\": {{\"message\": \"Stream error: {}\"}}}}\n\n", e); - let _ = tx.send(Ok(Bytes::from(error_event))); - return; - } - } - } - - next_output_index = handler.next_output_index(); - if let Some(id) = handler.original_response_id().map(|s| s.to_string()) { - preserved_response_id = Some(id); - } - - // If no tool calls, we're done - stream is complete - if !tool_calls_detected { - if !Self::send_final_response_event( - &handler, - &tx, - &mut sequence_number, - &state, - Some(&active_mcp_clone), - &original_request, - previous_response_id.as_deref(), - server_label, - ) { - return; - } - - let final_response_json = if should_store || persist_needed { - handler.accumulator.into_final_response() - } else { - None - }; - - if let Some(mut response_json) = final_response_json { - if let Some(ref id) = preserved_response_id { - if let Some(obj) = response_json.as_object_mut() { - obj.insert("id".to_string(), Value::String(id.clone())); - } - } - Self::inject_mcp_metadata_streaming( - &mut response_json, - &state, - &active_mcp_clone, - server_label, - ); - - Self::mask_tools_as_mcp(&mut response_json, &original_request); - Self::patch_streaming_response_json( - &mut response_json, - &original_request, - previous_response_id.as_deref(), - ); - - if should_store { - if let Err(err) = Self::store_response_impl( - &storage, - &response_json, - &original_request, - ) - .await - { - warn!("Failed to store streaming response: {}", err); - } - } - - if persist_needed { - if let Some(conv_id) = original_request.conversation.clone() { - if let Err(err) = persist_items_with_storages( - conv_storage.clone(), - conv_item_storage.clone(), - conv_id, - original_request.clone(), - response_json.clone(), - ) - .await - { - warn!( - "Failed to persist conversation items (stream + MCP): {}", - err - ); - } - } - } - } - - let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); - return; - } - - // Execute tools - let pending_calls = handler.take_pending_calls(); - - // Check iteration limit - state.iteration += 1; - state.total_calls += pending_calls.len(); - - let effective_limit = match max_tool_calls { - Some(user_max) => user_max.min(loop_config.max_iterations), - None => loop_config.max_iterations, - }; - - if state.total_calls > effective_limit { - warn!( - "Reached tool call limit during streaming: {}", - effective_limit - ); - let error_event = "event: error\ndata: {\"error\": {\"message\": \"Exceeded max_tool_calls limit\"}}\n\n".to_string(); - let _ = tx.send(Ok(Bytes::from(error_event))); - let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); - return; - } - - // Execute all pending tool calls - if !Self::execute_streaming_tool_calls( - pending_calls, - &active_mcp_clone, - &tx, - &mut state, - server_label, - &mut sequence_number, - ) - .await - { - // Client disconnected during tool execution - return; - } - - // Build resume payload - match Self::build_resume_payload( - &base_payload, - &state.conversation_history, - &state.original_input, - &tools_json, - true, // is_streaming = true - ) { - Ok(resume_payload) => { - current_payload = resume_payload; - // Mark that we're no longer on the first iteration - is_first_iteration = false; - // Continue loop to make next streaming request - } - Err(e) => { - let error_event = format!("event: error\ndata: {{\"error\": {{\"message\": \"Failed to build resume payload: {}\"}}}}\n\n", e); - let _ = tx.send(Ok(Bytes::from(error_event))); - let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); - return; - } - } - } - }); - - let body_stream = UnboundedReceiverStream::new(rx); - let mut response = Response::new(Body::from_stream(body_stream)); - *response.status_mut() = StatusCode::OK; - response - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); - response - } - - /// Parse an SSE block into event name and data - /// - /// Returns borrowed strings when possible to avoid allocations in hot paths. - /// Only allocates when multiple data lines need to be joined. - fn parse_sse_block(block: &str) -> (Option<&str>, Cow<'_, str>) { - let mut event_name: Option<&str> = None; - let mut data_lines: Vec<&str> = Vec::new(); - - for line in block.lines() { - if let Some(rest) = line.strip_prefix("event:") { - event_name = Some(rest.trim()); - } else if let Some(rest) = line.strip_prefix("data:") { - data_lines.push(rest.trim_start()); - } - } - - let data = if data_lines.len() == 1 { - Cow::Borrowed(data_lines[0]) - } else { - Cow::Owned(data_lines.join("\n")) - }; - - (event_name, data) - } - - // Note: transform_streaming_event has been replaced by apply_event_transformations_inplace - // which is more efficient (parses JSON only once instead of twice) - - /// Send mcp_list_tools events to client at the start of streaming - /// Returns false if client disconnected - fn send_mcp_list_tools_events( - tx: &mpsc::UnboundedSender>, - mcp: &Arc, - server_label: &str, - output_index: usize, - sequence_number: &mut u64, - ) -> bool { - let tools_item_full = Self::build_mcp_list_tools_item(mcp, server_label); - let item_id = tools_item_full - .get("id") - .and_then(|v| v.as_str()) - .unwrap_or(""); - - // Create empty tools version for the initial added event - let mut tools_item_empty = tools_item_full.clone(); - if let Some(obj) = tools_item_empty.as_object_mut() { - obj.insert("tools".to_string(), json!([])); - } - - // Event 1: response.output_item.added with empty tools - let event1_payload = json!({ - "type": event_types::OUTPUT_ITEM_ADDED, - "sequence_number": *sequence_number, - "output_index": output_index, - "item": tools_item_empty - }); - *sequence_number += 1; - let event1 = format!( - "event: {}\ndata: {}\n\n", - event_types::OUTPUT_ITEM_ADDED, - event1_payload - ); - if tx.send(Ok(Bytes::from(event1))).is_err() { - return false; // Client disconnected - } - - // Event 2: response.mcp_list_tools.in_progress - let event2_payload = json!({ - "type": event_types::MCP_LIST_TOOLS_IN_PROGRESS, - "sequence_number": *sequence_number, - "output_index": output_index, - "item_id": item_id - }); - *sequence_number += 1; - let event2 = format!( - "event: {}\ndata: {}\n\n", - event_types::MCP_LIST_TOOLS_IN_PROGRESS, - event2_payload - ); - if tx.send(Ok(Bytes::from(event2))).is_err() { - return false; - } - - // Event 3: response.mcp_list_tools.completed - let event3_payload = json!({ - "type": event_types::MCP_LIST_TOOLS_COMPLETED, - "sequence_number": *sequence_number, - "output_index": output_index, - "item_id": item_id - }); - *sequence_number += 1; - let event3 = format!( - "event: {}\ndata: {}\n\n", - event_types::MCP_LIST_TOOLS_COMPLETED, - event3_payload - ); - if tx.send(Ok(Bytes::from(event3))).is_err() { - return false; - } - - // Event 4: response.output_item.done with full tools list - let event4_payload = json!({ - "type": event_types::OUTPUT_ITEM_DONE, - "sequence_number": *sequence_number, - "output_index": output_index, - "item": tools_item_full - }); - *sequence_number += 1; - let event4 = format!( - "event: {}\ndata: {}\n\n", - event_types::OUTPUT_ITEM_DONE, - event4_payload - ); - tx.send(Ok(Bytes::from(event4))).is_ok() - } - - /// Send mcp_call completion events after tool execution - /// Returns false if client disconnected - fn send_mcp_call_completion_events_with_error( - tx: &mpsc::UnboundedSender>, - call: &FunctionCallInProgress, - output: &str, - server_label: &str, - success: bool, - error_msg: Option<&str>, - sequence_number: &mut u64, - ) -> bool { - let effective_output_index = call.effective_output_index(); - - // Build mcp_call item (reuse existing function) - let mcp_call_item = Self::build_mcp_call_item( - &call.name, - &call.arguments_buffer, - output, - server_label, - success, - error_msg, - ); - - // Get the mcp_call item_id - let item_id = mcp_call_item - .get("id") - .and_then(|v| v.as_str()) - .unwrap_or(""); - - // Event 1: response.mcp_call.completed - let completed_payload = json!({ - "type": event_types::MCP_CALL_COMPLETED, - "sequence_number": *sequence_number, - "output_index": effective_output_index, - "item_id": item_id - }); - *sequence_number += 1; - - let completed_event = format!( - "event: {}\ndata: {}\n\n", - event_types::MCP_CALL_COMPLETED, - completed_payload - ); - if tx.send(Ok(Bytes::from(completed_event))).is_err() { - return false; - } - - // Event 2: response.output_item.done (with completed mcp_call) - let done_payload = json!({ - "type": event_types::OUTPUT_ITEM_DONE, - "sequence_number": *sequence_number, - "output_index": effective_output_index, - "item": mcp_call_item - }); - *sequence_number += 1; - - let done_event = format!( - "event: {}\ndata: {}\n\n", - event_types::OUTPUT_ITEM_DONE, - done_payload - ); - tx.send(Ok(Bytes::from(done_event))).is_ok() - } - - #[allow(clippy::too_many_arguments)] - fn send_final_response_event( - handler: &StreamingToolHandler, - tx: &mpsc::UnboundedSender>, - sequence_number: &mut u64, - state: &ToolLoopState, - active_mcp: Option<&Arc>, - original_request: &ResponsesRequest, - previous_response_id: Option<&str>, - server_label: &str, - ) -> bool { - let mut final_response = match handler.snapshot_final_response() { - Some(resp) => resp, - None => { - warn!("Final response snapshot unavailable; skipping synthetic completion event"); - return true; - } - }; - - if let Some(original_id) = handler.original_response_id() { - if let Some(obj) = final_response.as_object_mut() { - obj.insert("id".to_string(), Value::String(original_id.to_string())); - } - } - - if let Some(mcp) = active_mcp { - Self::inject_mcp_metadata_streaming(&mut final_response, state, mcp, server_label); - } - - Self::mask_tools_as_mcp(&mut final_response, original_request); - Self::patch_streaming_response_json( - &mut final_response, - original_request, - previous_response_id, - ); - - if let Some(obj) = final_response.as_object_mut() { - obj.insert("status".to_string(), Value::String("completed".to_string())); - } - - let completed_payload = json!({ - "type": event_types::RESPONSE_COMPLETED, - "sequence_number": *sequence_number, - "response": final_response - }); - *sequence_number += 1; - - let completed_event = format!( - "event: {}\ndata: {}\n\n", - event_types::RESPONSE_COMPLETED, - completed_payload - ); - tx.send(Ok(Bytes::from(completed_event))).is_ok() - } - - /// Inject MCP metadata into a streaming response - fn inject_mcp_metadata_streaming( - response: &mut Value, - state: &ToolLoopState, - mcp: &Arc, - server_label: &str, - ) { - if let Some(output_array) = response.get_mut("output").and_then(|v| v.as_array_mut()) { - output_array.retain(|item| { - item.get("type").and_then(|t| t.as_str()) - != Some(event_types::ITEM_TYPE_MCP_LIST_TOOLS) - }); - - let list_tools_item = Self::build_mcp_list_tools_item(mcp, server_label); - output_array.insert(0, list_tools_item); - - let mcp_call_items = - Self::build_executed_mcp_call_items(&state.conversation_history, server_label); - let mut insert_pos = 1; - for item in mcp_call_items { - output_array.insert(insert_pos, item); - insert_pos += 1; - } - } else if let Some(obj) = response.as_object_mut() { - let mut output_items = Vec::new(); - output_items.push(Self::build_mcp_list_tools_item(mcp, server_label)); - output_items.extend(Self::build_executed_mcp_call_items( - &state.conversation_history, - server_label, - )); - obj.insert("output".to_string(), Value::Array(output_items)); - } - } - - async fn store_response_internal( - &self, - response_json: &Value, - original_body: &ResponsesRequest, - ) -> Result<(), String> { - if !original_body.store { - return Ok(()); - } - - match Self::store_response_impl(&self.response_storage, response_json, original_body).await - { - Ok(response_id) => { - info!(response_id = %response_id.0, "Stored response locally"); - Ok(()) - } - Err(e) => Err(e), - } - } - - async fn store_response_impl( - response_storage: &SharedResponseStorage, - response_json: &Value, - original_body: &ResponsesRequest, - ) -> Result { - let input_text = match &original_body.input { - ResponseInput::Text(text) => text.clone(), - ResponseInput::Items(_) => "complex input".to_string(), - }; - - let output_text = Self::extract_primary_output_text(response_json).unwrap_or_default(); - - let mut stored_response = StoredResponse::new(input_text, output_text, None); - - stored_response.instructions = response_json - .get("instructions") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - .or_else(|| original_body.instructions.clone()); - - stored_response.model = response_json - .get("model") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - .or_else(|| original_body.model.clone()); - - stored_response.user = response_json - .get("user") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - .or_else(|| original_body.user.clone()); - - // Set conversation id from request if provided - if let Some(conv_id) = original_body.conversation.clone() { - stored_response.conversation_id = Some(conv_id); - } - - stored_response.metadata = response_json - .get("metadata") - .and_then(|v| v.as_object()) - .map(|m| { - m.iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect::>() - }) - .unwrap_or_else(|| original_body.metadata.clone().unwrap_or_default()); - - stored_response.previous_response_id = response_json - .get("previous_response_id") - .and_then(|v| v.as_str()) - .map(ResponseId::from) - .or_else(|| { - original_body - .previous_response_id - .as_ref() - .map(|id| ResponseId::from(id.as_str())) - }); - - if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) { - stored_response.id = ResponseId::from(id_str); - } - - stored_response.raw_response = response_json.clone(); - - response_storage - .store_response(stored_response) - .await - .map_err(|e| format!("Failed to store response: {}", e)) - } - - fn patch_streaming_response_json( - response_json: &mut Value, - original_body: &ResponsesRequest, - original_previous_response_id: Option<&str>, - ) { - if let Some(obj) = response_json.as_object_mut() { - if let Some(prev_id) = original_previous_response_id { - let should_insert = obj - .get("previous_response_id") - .map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false)) - .unwrap_or(true); - if should_insert { - obj.insert( - "previous_response_id".to_string(), - Value::String(prev_id.to_string()), - ); - } - } - - if !obj.contains_key("instructions") - || obj - .get("instructions") - .map(|v| v.is_null()) - .unwrap_or(false) - { - if let Some(instructions) = &original_body.instructions { - obj.insert( - "instructions".to_string(), - Value::String(instructions.clone()), - ); - } - } - - if !obj.contains_key("metadata") - || obj.get("metadata").map(|v| v.is_null()).unwrap_or(false) - { - if let Some(metadata) = &original_body.metadata { - let metadata_map: serde_json::Map = metadata - .iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect(); - obj.insert("metadata".to_string(), Value::Object(metadata_map)); - } - } - - obj.insert("store".to_string(), Value::Bool(original_body.store)); - - if obj - .get("model") - .and_then(|v| v.as_str()) - .map(|s| s.is_empty()) - .unwrap_or(true) - { - if let Some(model) = &original_body.model { - obj.insert("model".to_string(), Value::String(model.clone())); - } - } - - if obj.get("user").map(|v| v.is_null()).unwrap_or(false) { - if let Some(user) = &original_body.user { - obj.insert("user".to_string(), Value::String(user.clone())); - } - } - - // Attach conversation id for client response if present (final aggregated JSON) - if let Some(conv_id) = original_body.conversation.clone() { - obj.insert("conversation".to_string(), json!({"id": conv_id})); - } - } - } - - fn rewrite_streaming_block( - block: &str, - original_body: &ResponsesRequest, - original_previous_response_id: Option<&str>, - ) -> Option { - let trimmed = block.trim(); - if trimmed.is_empty() { - return None; - } - - let mut data_lines: Vec = Vec::new(); - - for line in trimmed.lines() { - if line.starts_with("data:") { - data_lines.push(line.trim_start_matches("data:").trim_start().to_string()); - } - } - - if data_lines.is_empty() { - return None; - } - - let payload = data_lines.join("\n"); - let mut parsed: Value = match serde_json::from_str(&payload) { - Ok(value) => value, - Err(err) => { - warn!("Failed to parse streaming JSON payload: {}", err); - return None; - } - }; - - let event_type = parsed - .get("type") - .and_then(|v| v.as_str()) - .unwrap_or_default(); - - let should_patch = matches!( - event_type, - event_types::RESPONSE_CREATED - | event_types::RESPONSE_IN_PROGRESS - | event_types::RESPONSE_COMPLETED - ); - - if !should_patch { - return None; - } - - let mut changed = false; - if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) { - let desired_store = Value::Bool(original_body.store); - if response_obj.get("store") != Some(&desired_store) { - response_obj.insert("store".to_string(), desired_store); - changed = true; - } - - if let Some(prev_id) = original_previous_response_id { - let needs_previous = response_obj - .get("previous_response_id") - .map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false)) - .unwrap_or(true); - - if needs_previous { - response_obj.insert( - "previous_response_id".to_string(), - Value::String(prev_id.to_string()), - ); - changed = true; - } - } - - // Attach conversation id into streaming event response content with ordering - if let Some(conv_id) = original_body.conversation.clone() { - response_obj.insert("conversation".to_string(), json!({"id": conv_id})); - changed = true; - } - } - - if !changed { - return None; - } - - let new_payload = match serde_json::to_string(&parsed) { - Ok(json) => json, - Err(err) => { - warn!("Failed to serialize modified streaming payload: {}", err); - return None; - } - }; - - let mut rebuilt_lines = Vec::new(); - let mut data_written = false; - for line in trimmed.lines() { - if line.starts_with("data:") { - if !data_written { - rebuilt_lines.push(format!("data: {}", new_payload)); - data_written = true; - } - } else { - rebuilt_lines.push(line.to_string()); - } - } - - if !data_written { - rebuilt_lines.push(format!("data: {}", new_payload)); - } - - Some(rebuilt_lines.join("\n")) - } - fn extract_primary_output_text(response_json: &Value) -> Option { - if let Some(items) = response_json.get("output").and_then(|v| v.as_array()) { - for item in items { - if let Some(content) = item.get("content").and_then(|v| v.as_array()) { - for part in content { - if part - .get("type") - .and_then(|v| v.as_str()) - .map(|t| t == "output_text") - .unwrap_or(false) - { - if let Some(text) = part.get("text").and_then(|v| v.as_str()) { - return Some(text.to_string()); - } - } - } - } - } - } - - None - } -} - -impl OpenAIRouter { - fn extract_function_call(resp: &Value) -> Option<(String, String, String)> { - let output = resp.get("output")?.as_array()?; - for item in output { - let obj = item.as_object()?; - let t = obj.get("type")?.as_str()?; - if t == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL - || t == event_types::ITEM_TYPE_FUNCTION_CALL - { - let call_id = obj - .get("call_id") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - .or_else(|| { - obj.get("id") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - })?; - let name = obj.get("name")?.as_str()?.to_string(); - let arguments = obj.get("arguments")?.as_str()?.to_string(); - return Some((call_id, name, arguments)); - } - } - None - } - - /// Replace returned tools with the original request's MCP tool block (if present) so - /// external clients see MCP semantics rather than internal function tools. - /// Build MCP tools array value without cloning entire response object - fn build_mcp_tools_value(original_body: &ResponsesRequest) -> Option { - let mcp_tool = original_body - .tools - .iter() - .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?; - - let mut m = serde_json::Map::new(); - m.insert("type".to_string(), Value::String("mcp".to_string())); - if let Some(label) = &mcp_tool.server_label { - m.insert("server_label".to_string(), Value::String(label.clone())); - } - if let Some(url) = &mcp_tool.server_url { - m.insert("server_url".to_string(), Value::String(url.clone())); - } - if let Some(desc) = &mcp_tool.server_description { - m.insert( - "server_description".to_string(), - Value::String(desc.clone()), - ); - } - if let Some(req) = &mcp_tool.require_approval { - m.insert("require_approval".to_string(), Value::String(req.clone())); - } - if let Some(allowed) = &mcp_tool.allowed_tools { - m.insert( - "allowed_tools".to_string(), - Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()), - ); - } - - Some(Value::Array(vec![Value::Object(m)])) - } - - fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) { - let mcp_tool = original_body - .tools - .iter() - .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some()); - let Some(t) = mcp_tool else { - return; - }; - - let mut m = serde_json::Map::new(); - m.insert("type".to_string(), Value::String("mcp".to_string())); - if let Some(label) = &t.server_label { - m.insert("server_label".to_string(), Value::String(label.clone())); - } - if let Some(url) = &t.server_url { - m.insert("server_url".to_string(), Value::String(url.clone())); - } - if let Some(desc) = &t.server_description { - m.insert( - "server_description".to_string(), - Value::String(desc.clone()), - ); - } - if let Some(req) = &t.require_approval { - m.insert("require_approval".to_string(), Value::String(req.clone())); - } - if let Some(allowed) = &t.allowed_tools { - m.insert( - "allowed_tools".to_string(), - Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()), - ); - } - - if let Some(obj) = resp.as_object_mut() { - obj.insert("tools".to_string(), Value::Array(vec![Value::Object(m)])); - obj.entry("tool_choice") - .or_insert(Value::String("auto".to_string())); - } - } - - async fn execute_mcp_call( - mcp_mgr: &Arc, - tool_name: &str, - args_json_str: &str, - ) -> Result<(String, String), String> { - let args_value: Value = - serde_json::from_str(args_json_str).map_err(|e| format!("parse tool args: {}", e))?; - let args_obj = args_value.as_object().cloned(); - - let server_name = mcp_mgr - .get_tool(tool_name) - .map(|t| t.server) - .ok_or_else(|| format!("tool not found: {}", tool_name))?; - - let result = mcp_mgr - .call_tool(tool_name, args_obj) - .await - .map_err(|e| format!("tool call failed: {}", e))?; - - let output_str = serde_json::to_string(&result) - .map_err(|e| format!("Failed to serialize tool result: {}", e))?; - Ok((server_name, output_str)) - } - - /// Build a resume payload with conversation history - fn build_resume_payload( - base_payload: &Value, - conversation_history: &[Value], - original_input: &ResponseInput, - tools_json: &Value, - is_streaming: bool, - ) -> Result { - // Clone the base payload which already has cleaned fields - let mut payload = base_payload.clone(); - - let obj = payload - .as_object_mut() - .ok_or_else(|| "payload not an object".to_string())?; - - // Build input array: start with original user input - let mut input_array = Vec::new(); - - // Add original user message - // For structured input, serialize the original input items - match original_input { - ResponseInput::Text(text) => { - let user_item = json!({ - "type": "message", - "role": "user", - "content": [{ "type": "input_text", "text": text }] - }); - input_array.push(user_item); - } - ResponseInput::Items(items) => { - // Items are already structured ResponseInputOutputItem, convert to JSON - if let Ok(items_value) = to_value(items) { - if let Some(items_arr) = items_value.as_array() { - input_array.extend_from_slice(items_arr); - } - } - } - } - - // Add all conversation history (function calls and outputs) - input_array.extend_from_slice(conversation_history); - - obj.insert("input".to_string(), Value::Array(input_array)); - - // Use the transformed tools (function tools, not MCP tools) - if let Some(tools_arr) = tools_json.as_array() { - if !tools_arr.is_empty() { - obj.insert("tools".to_string(), tools_json.clone()); - } - } - - // Set streaming mode based on caller's context - obj.insert("stream".to_string(), Value::Bool(is_streaming)); - obj.insert("store".to_string(), Value::Bool(false)); - - // Note: SGLang-specific fields were already removed from base_payload - // before it was passed to execute_tool_loop (see route_responses lines 1935-1946) - - Ok(payload) - } - - /// Helper function to build mcp_call items from executed tool calls in conversation history - fn build_executed_mcp_call_items( - conversation_history: &[Value], - server_label: &str, - ) -> Vec { - let mut mcp_call_items = Vec::new(); - - for item in conversation_history { - if item.get("type").and_then(|t| t.as_str()) - == Some(event_types::ITEM_TYPE_FUNCTION_CALL) - { - let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or(""); - let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); - let args = item - .get("arguments") - .and_then(|v| v.as_str()) - .unwrap_or("{}"); - - // Find corresponding output - let output_item = conversation_history.iter().find(|o| { - o.get("type").and_then(|t| t.as_str()) == Some("function_call_output") - && o.get("call_id").and_then(|c| c.as_str()) == Some(call_id) - }); - - let output_str = output_item - .and_then(|o| o.get("output").and_then(|v| v.as_str())) - .unwrap_or("{}"); - - // Check if output contains error by parsing JSON - let is_error = serde_json::from_str::(output_str) - .map(|v| v.get("error").is_some()) - .unwrap_or(false); - - let mcp_call_item = Self::build_mcp_call_item( - tool_name, - args, - output_str, - server_label, - !is_error, - if is_error { - Some("Tool execution failed") - } else { - None - }, - ); - mcp_call_items.push(mcp_call_item); - } - } - - mcp_call_items - } - - /// Build an incomplete response when limits are exceeded - fn build_incomplete_response( - mut response: Value, - state: ToolLoopState, - reason: &str, - active_mcp: &Arc, - original_body: &ResponsesRequest, - ) -> Result { - let obj = response - .as_object_mut() - .ok_or_else(|| "response not an object".to_string())?; - - // Set status to completed (not failed - partial success) - obj.insert("status".to_string(), Value::String("completed".to_string())); - - // Set incomplete_details - obj.insert( - "incomplete_details".to_string(), - json!({ "reason": reason }), - ); - - // Convert any function_call in output to mcp_call format - if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) { - let server_label = original_body - .tools - .iter() - .find(|t| matches!(t.r#type, ResponseToolType::Mcp)) - .and_then(|t| t.server_label.as_deref()) - .unwrap_or("mcp"); - - // Find any function_call items and convert them to mcp_call (incomplete) - let mut mcp_call_items = Vec::new(); - for item in output_array.iter() { - if item.get("type").and_then(|t| t.as_str()) - == Some(event_types::ITEM_TYPE_FUNCTION_TOOL_CALL) - { - let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); - let args = item - .get("arguments") - .and_then(|v| v.as_str()) - .unwrap_or("{}"); - - // Mark as incomplete - not executed - let mcp_call_item = Self::build_mcp_call_item( - tool_name, - args, - "", // No output - wasn't executed - server_label, - false, // Not successful - Some("Not executed - response stopped due to limit"), - ); - mcp_call_items.push(mcp_call_item); - } - } - - // Add mcp_list_tools and executed mcp_call items at the beginning - if state.total_calls > 0 || !mcp_call_items.is_empty() { - let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label); - output_array.insert(0, list_tools_item); - - // Add mcp_call items for executed calls using helper - let executed_items = - Self::build_executed_mcp_call_items(&state.conversation_history, server_label); - - let mut insert_pos = 1; - for item in executed_items { - output_array.insert(insert_pos, item); - insert_pos += 1; - } - - // Add incomplete mcp_call items - for item in mcp_call_items { - output_array.insert(insert_pos, item); - insert_pos += 1; - } - } - } - - // Add warning to metadata - if let Some(metadata_val) = obj.get_mut("metadata") { - if let Some(metadata_obj) = metadata_val.as_object_mut() { - if let Some(mcp_val) = metadata_obj.get_mut("mcp") { - if let Some(mcp_obj) = mcp_val.as_object_mut() { - mcp_obj.insert( - "truncation_warning".to_string(), - Value::String(format!( - "Loop terminated at {} iterations, {} total calls (reason: {})", - state.iteration, state.total_calls, reason - )), - ); - } - } - } - } - - Ok(response) - } - - /// Execute the tool calling loop - async fn execute_tool_loop( - &self, - url: &str, - headers: Option<&HeaderMap>, - initial_payload: Value, - original_body: &ResponsesRequest, - active_mcp: &Arc, - config: &McpLoopConfig, - ) -> Result { - let mut state = ToolLoopState::new(original_body.input.clone()); - - // Get max_tool_calls from request (None means no user-specified limit) - let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize); - - // Keep initial_payload as base template (already has fields cleaned) - let base_payload = initial_payload.clone(); - let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([])); - let mut current_payload = initial_payload; - - info!( - "Starting tool loop: max_tool_calls={:?}, max_iterations={}", - max_tool_calls, config.max_iterations - ); - - loop { - // Make request to upstream - let request_builder = self.client.post(url).json(¤t_payload); - let request_builder = if let Some(headers) = headers { - apply_request_headers(headers, request_builder, true) - } else { - request_builder - }; - - let response = request_builder - .send() - .await - .map_err(|e| format!("upstream request failed: {}", e))?; - - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - return Err(format!("upstream error {}: {}", status, body)); - } - - let mut response_json = response - .json::() - .await - .map_err(|e| format!("parse response: {}", e))?; - - // Check for function call - if let Some((call_id, tool_name, args_json_str)) = - Self::extract_function_call(&response_json) - { - state.iteration += 1; - state.total_calls += 1; - - info!( - "Tool loop iteration {}: calling {} (call_id: {})", - state.iteration, tool_name, call_id - ); - - // Check combined limit: use minimum of user's max_tool_calls (if set) and safety max_iterations - let effective_limit = match max_tool_calls { - Some(user_max) => user_max.min(config.max_iterations), - None => config.max_iterations, - }; - - if state.total_calls > effective_limit { - if let Some(user_max) = max_tool_calls { - if state.total_calls > user_max { - warn!("Reached user-specified max_tool_calls limit: {}", user_max); - } else { - warn!( - "Reached safety max_iterations limit: {}", - config.max_iterations - ); - } - } else { - warn!( - "Reached safety max_iterations limit: {}", - config.max_iterations - ); - } - - return Self::build_incomplete_response( - response_json, - state, - "max_tool_calls", - active_mcp, - original_body, - ); - } - - // Execute tool - let call_result = - Self::execute_mcp_call(active_mcp, &tool_name, &args_json_str).await; - - let output_str = match call_result { - Ok((_, output)) => output, - Err(err) => { - warn!("Tool execution failed: {}", err); - // Return error as output, let model decide how to proceed - json!({ "error": err }).to_string() - } - }; - - // Record the call - state.record_call(call_id, tool_name, args_json_str, output_str); - - // Build resume payload - current_payload = Self::build_resume_payload( - &base_payload, - &state.conversation_history, - &state.original_input, - &tools_json, - false, // is_streaming = false (non-streaming tool loop) - )?; - } else { - // No more tool calls, we're done - info!( - "Tool loop completed: {} iterations, {} total calls", - state.iteration, state.total_calls - ); - - // Inject MCP output items if we executed any tools - if state.total_calls > 0 { - let server_label = original_body - .tools - .iter() - .find(|t| matches!(t.r#type, ResponseToolType::Mcp)) - .and_then(|t| t.server_label.as_deref()) - .unwrap_or("mcp"); - - // Build mcp_list_tools item - let list_tools_item = Self::build_mcp_list_tools_item(active_mcp, server_label); - - // Insert at beginning of output array - if let Some(output_array) = response_json - .get_mut("output") - .and_then(|v| v.as_array_mut()) - { - output_array.insert(0, list_tools_item); - - // Build mcp_call items using helper function - let mcp_call_items = Self::build_executed_mcp_call_items( - &state.conversation_history, - server_label, - ); - - // Insert mcp_call items after mcp_list_tools using mutable position - let mut insert_pos = 1; - for item in mcp_call_items { - output_array.insert(insert_pos, item); - insert_pos += 1; - } - } - } - - return Ok(response_json); - } - } - } - - /// Generate a unique ID for MCP output items (similar to OpenAI format) - fn generate_mcp_id(prefix: &str) -> String { - use rand::RngCore; - let mut rng = rand::rng(); - let mut bytes = [0u8; 30]; - rng.fill_bytes(&mut bytes); - let hex_string: String = bytes.iter().map(|b| format!("{:02x}", b)).collect(); - format!("{}_{}", prefix, hex_string) - } - - /// Build an mcp_list_tools output item - fn build_mcp_list_tools_item( - mcp: &Arc, - server_label: &str, - ) -> Value { - let tools = mcp.list_tools(); - let tools_json: Vec = tools - .iter() - .map(|t| { - json!({ - "name": t.name, - "description": t.description, - "input_schema": t.parameters.clone().unwrap_or_else(|| json!({ - "type": "object", - "properties": {}, - "additionalProperties": false - })), - "annotations": { - "read_only": false - } - }) - }) - .collect(); - - json!({ - "id": Self::generate_mcp_id("mcpl"), - "type": event_types::ITEM_TYPE_MCP_LIST_TOOLS, - "server_label": server_label, - "tools": tools_json - }) - } - - /// Build an mcp_call output item - fn build_mcp_call_item( - tool_name: &str, - arguments: &str, - output: &str, - server_label: &str, - success: bool, - error: Option<&str>, - ) -> Value { - json!({ - "id": Self::generate_mcp_id("mcp"), - "type": event_types::ITEM_TYPE_MCP_CALL, - "status": if success { "completed" } else { "failed" }, - "approval_request_id": Value::Null, - "arguments": arguments, - "error": error, - "name": tool_name, - "output": output, - "server_label": server_label - }) - } -} - -#[async_trait] -impl super::super::RouterTrait for OpenAIRouter { - fn as_any(&self) -> &dyn Any { - self - } - - async fn health_generate(&self, _req: Request) -> Response { - // Simple upstream probe: GET {base}/v1/models without auth - let url = format!("{}/v1/models", self.base_url); - match self - .client - .get(&url) - .timeout(std::time::Duration::from_secs(2)) - .send() - .await - { - Ok(resp) => { - let code = resp.status(); - // Treat success and auth-required as healthy (endpoint reachable) - if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 { - (StatusCode::OK, "OK").into_response() - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - format!("Upstream status: {}", code), - ) - .into_response() - } - } - Err(e) => ( - StatusCode::SERVICE_UNAVAILABLE, - format!("Upstream error: {}", e), - ) - .into_response(), - } - } - - async fn get_server_info(&self, _req: Request) -> Response { - let info = json!({ - "router_type": "openai", - "workers": 1, - "base_url": &self.base_url - }); - (StatusCode::OK, info.to_string()).into_response() - } - - async fn get_models(&self, req: Request) -> Response { - // Proxy to upstream /v1/models; forward Authorization header if provided - let headers = req.headers(); - - let mut upstream = self.client.get(format!("{}/v1/models", self.base_url)); - - if let Some(auth) = headers - .get("authorization") - .or_else(|| headers.get("Authorization")) - { - upstream = upstream.header("Authorization", auth); - } - - match upstream.send().await { - Ok(res) => { - let status = StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - let content_type = res.headers().get(CONTENT_TYPE).cloned(); - match res.bytes().await { - Ok(body) => { - let mut response = Response::new(Body::from(body)); - *response.status_mut() = status; - if let Some(ct) = content_type { - response.headers_mut().insert(CONTENT_TYPE, ct); - } - response - } - Err(e) => ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to read upstream response: {}", e), - ) - .into_response(), - } - } - Err(e) => ( - StatusCode::BAD_GATEWAY, - format!("Failed to contact upstream: {}", e), - ) - .into_response(), - } - } - - async fn get_model_info(&self, _req: Request) -> Response { - // Not directly supported without model param; return 501 - ( - StatusCode::NOT_IMPLEMENTED, - "get_model_info not implemented for OpenAI router", - ) - .into_response() - } - - async fn route_generate( - &self, - _headers: Option<&HeaderMap>, - _body: &GenerateRequest, - _model_id: Option<&str>, - ) -> Response { - // Generate endpoint is SGLang-specific, not supported for OpenAI backend - ( - StatusCode::NOT_IMPLEMENTED, - "Generate endpoint not supported for OpenAI backend", - ) - .into_response() - } - - async fn route_chat( - &self, - headers: Option<&HeaderMap>, - body: &ChatCompletionRequest, - _model_id: Option<&str>, - ) -> Response { - if !self.circuit_breaker.can_execute() { - return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response(); - } - - // Serialize request body, removing SGLang-only fields - let mut payload = match to_value(body) { - Ok(v) => v, - Err(e) => { - return ( - StatusCode::BAD_REQUEST, - format!("Failed to serialize request: {}", e), - ) - .into_response(); - } - }; - if let Some(obj) = payload.as_object_mut() { - for key in [ - "top_k", - "min_p", - "min_tokens", - "regex", - "ebnf", - "stop_token_ids", - "no_stop_trim", - "ignore_eos", - "continue_final_message", - "skip_special_tokens", - "lora_path", - "session_params", - "separate_reasoning", - "stream_reasoning", - "chat_template_kwargs", - "return_hidden_states", - "repetition_penalty", - "sampling_seed", - ] { - obj.remove(key); - } - } - - let url = format!("{}/v1/chat/completions", self.base_url); - let mut req = self.client.post(&url).json(&payload); - - // Forward Authorization header if provided - if let Some(h) = headers { - if let Some(auth) = h.get("authorization").or_else(|| h.get("Authorization")) { - req = req.header("Authorization", auth); - } - } - - // Accept SSE when stream=true - if body.stream { - req = req.header("Accept", "text/event-stream"); - } - - let resp = match req.send().await { - Ok(r) => r, - Err(e) => { - self.circuit_breaker.record_failure(); - return ( - StatusCode::SERVICE_UNAVAILABLE, - format!("Failed to contact upstream: {}", e), - ) - .into_response(); - } - }; - - let status = StatusCode::from_u16(resp.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - - if !body.stream { - // Capture Content-Type before consuming response body - let content_type = resp.headers().get(CONTENT_TYPE).cloned(); - match resp.bytes().await { - Ok(body) => { - self.circuit_breaker.record_success(); - let mut response = Response::new(Body::from(body)); - *response.status_mut() = status; - if let Some(ct) = content_type { - response.headers_mut().insert(CONTENT_TYPE, ct); - } - response - } - Err(e) => { - self.circuit_breaker.record_failure(); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to read response: {}", e), - ) - .into_response() - } - } - } else { - // Stream SSE bytes to client - let stream = resp.bytes_stream(); - let (tx, rx) = mpsc::unbounded_channel(); - tokio::spawn(async move { - let mut s = stream; - while let Some(chunk) = s.next().await { - match chunk { - Ok(bytes) => { - if tx.send(Ok(bytes)).is_err() { - break; - } - } - Err(e) => { - let _ = tx.send(Err(format!("Stream error: {}", e))); - break; - } - } - } - }); - let mut response = Response::new(Body::from_stream(UnboundedReceiverStream::new(rx))); - *response.status_mut() = status; - response - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); - response - } - } - - async fn route_completion( - &self, - _headers: Option<&HeaderMap>, - _body: &CompletionRequest, - _model_id: Option<&str>, - ) -> Response { - // Completion endpoint not implemented for OpenAI backend - ( - StatusCode::NOT_IMPLEMENTED, - "Completion endpoint not implemented for OpenAI backend", - ) - .into_response() - } - - async fn route_responses( - &self, - headers: Option<&HeaderMap>, - body: &ResponsesRequest, - model_id: Option<&str>, - ) -> Response { - let url = format!("{}/v1/responses", self.base_url); - - info!( - requested_store = body.store, - is_streaming = body.stream, - "openai_responses_request" - ); - - // Validate mutually exclusive params: previous_response_id and conversation - // TODO: this validation logic should move the right place, also we need a proper error message module - if body.previous_response_id.is_some() && body.conversation.is_some() { - return ( - StatusCode::BAD_REQUEST, - Json(json!({ - "error": { - "message": "Mutually exclusive parameters. Ensure you are only providing one of: 'previous_response_id' or 'conversation'.", - "type": "invalid_request_error", - "param": Value::Null, - "code": "mutually_exclusive_parameters" - } - })), - ) - .into_response(); - } - - // Clone the body and override model if needed - let mut request_body = body.clone(); - if let Some(model) = model_id { - request_body.model = Some(model.to_string()); - } - // Do not forward conversation field upstream; retain for local persistence only - request_body.conversation = None; - - // Store the original previous_response_id for the response - let original_previous_response_id = request_body.previous_response_id.clone(); - - // Handle previous_response_id by loading prior context - let mut conversation_items: Option> = None; - if let Some(prev_id_str) = request_body.previous_response_id.clone() { - let prev_id = ResponseId::from(prev_id_str.as_str()); - match self - .response_storage - .get_response_chain(&prev_id, None) - .await - { - Ok(chain) => { - if !chain.responses.is_empty() { - let mut items = Vec::new(); - for stored in chain.responses.iter() { - let trimmed_id = stored.id.0.trim_start_matches("resp_"); - if !stored.input.is_empty() { - items.push(ResponseInputOutputItem::Message { - id: format!("msg_u_{}", trimmed_id), - role: "user".to_string(), - status: Some("completed".to_string()), - content: vec![ResponseContentPart::InputText { - text: stored.input.clone(), - }], - }); - } - if !stored.output.is_empty() { - items.push(ResponseInputOutputItem::Message { - id: format!("msg_a_{}", trimmed_id), - role: "assistant".to_string(), - status: Some("completed".to_string()), - content: vec![ResponseContentPart::OutputText { - text: stored.output.clone(), - annotations: vec![], - logprobs: None, - }], - }); - } - } - conversation_items = Some(items); - } else { - info!(previous_response_id = %prev_id_str, "previous chain empty"); - } - } - Err(err) => { - warn!(previous_response_id = %prev_id_str, %err, "failed to fetch previous response chain"); - } - } - // Clear previous_response_id from request since we're converting to conversation - request_body.previous_response_id = None; - } - - // If conversation is provided, attach its items as input to upstream request - if let Some(conv_id_str) = body.conversation.clone() { - let conv_id: ConversationId = conv_id_str.as_str().into(); - let mut items: Vec = Vec::new(); - // Fetch up to MAX_CONVERSATION_HISTORY_ITEMS items in ascending order - let params = ConversationItemsListParams { - limit: Self::MAX_CONVERSATION_HISTORY_ITEMS, - order: ConversationItemsSortOrder::Asc, - after: None, - }; - match self - .conversation_item_storage - .list_items(&conv_id, params) - .await - { - Ok(stored_items) => { - for it in stored_items { - match it.item_type.as_str() { - "message" => { - // content is expected to be an array of ResponseContentPart - let parts: Vec = match serde_json::from_value( - it.content.clone(), - ) { - Ok(parts) => parts, - Err(e) => { - warn!( - item_id = %it.id.0, - error = %e, - "Failed to deserialize conversation item content; skipping message item" - ); - continue; - } - }; - let role = it.role.unwrap_or_else(|| "user".to_string()); - items.push(ResponseInputOutputItem::Message { - id: it.id.0, - role, - content: parts, - status: it.status, - }); - } - _ => { - // Skip unsupported types for request input (e.g., MCP items) - } - } - } - } - Err(err) => { - warn!(conversation_id = %conv_id.0, error = %err.to_string(), "Failed to load conversation items for request input"); - } - } - - // Append the current request input at the end - match &request_body.input { - ResponseInput::Text(text) => { - items.push(ResponseInputOutputItem::Message { - id: format!("msg_u_current_{}", items.len()), - role: "user".to_string(), - status: Some("completed".to_string()), - content: vec![ResponseContentPart::InputText { text: text.clone() }], - }); - } - ResponseInput::Items(existing) => { - items.extend(existing.clone()); - } - } - request_body.input = ResponseInput::Items(items); - } - - if let Some(mut items) = conversation_items { - match &request_body.input { - ResponseInput::Text(text) => { - items.push(ResponseInputOutputItem::Message { - id: format!("msg_u_current_{}", items.len()), - role: "user".to_string(), - status: Some("completed".to_string()), - content: vec![ResponseContentPart::InputText { text: text.clone() }], - }); - } - ResponseInput::Items(existing) => { - items.extend(existing.clone()); - } - } - request_body.input = ResponseInput::Items(items); - } - - // Always set store=false for OpenAI (we store internally) - request_body.store = false; - - // Convert to JSON payload and strip SGLang-specific fields before forwarding - let mut payload = match to_value(&request_body) { - Ok(value) => value, - Err(err) => { - return ( - StatusCode::BAD_REQUEST, - format!("Failed to serialize responses request: {}", err), - ) - .into_response(); - } - }; - if let Some(obj) = payload.as_object_mut() { - for key in [ - "request_id", - "priority", - "frequency_penalty", - "presence_penalty", - "stop", - "top_k", - "min_p", - "repetition_penalty", - "conversation", - ] { - obj.remove(key); - } - } - - // Check if streaming is requested - if body.stream { - // Handle streaming response - self.handle_streaming_response( - url, - headers, - payload, - body, - original_previous_response_id, - ) - .await - } else { - // Handle non-streaming response - self.handle_non_streaming_response( - url, - headers, - payload, - body, - original_previous_response_id, - ) - .await - } - } - - async fn get_response( - &self, - _headers: Option<&HeaderMap>, - response_id: &str, - params: &ResponsesGetParams, - ) -> Response { - let stored_id = ResponseId::from(response_id); - if let Ok(Some(stored_response)) = self.response_storage.get_response(&stored_id).await { - let stream_requested = params.stream.unwrap_or(false); - let raw_value = stored_response.raw_response.clone(); - - if !raw_value.is_null() { - if stream_requested { - return ( - StatusCode::NOT_IMPLEMENTED, - "Streaming retrieval not yet implemented", - ) - .into_response(); - } - - return ( - StatusCode::OK, - [("content-type", "application/json")], - raw_value.to_string(), - ) - .into_response(); - } - - let openai_response = ResponsesResponse { - id: stored_response.id.0.clone(), - object: "response".to_string(), - created_at: stored_response.created_at.timestamp(), - status: ResponseStatus::Completed, - error: None, - incomplete_details: None, - instructions: stored_response.instructions.clone(), - max_output_tokens: None, - model: stored_response - .model - .unwrap_or_else(|| "gpt-4o".to_string()), - output: vec![ResponseOutputItem::Message { - id: format!("msg_{}", stored_response.id.0), - role: "assistant".to_string(), - status: "completed".to_string(), - content: vec![ResponseContentPart::OutputText { - text: stored_response.output, - annotations: vec![], - logprobs: None, - }], - }], - parallel_tool_calls: true, - previous_response_id: stored_response.previous_response_id.map(|id| id.0), - reasoning: None, - store: true, - temperature: Some(1.0), - text: Some(ResponseTextFormat { - format: TextFormatType { - format_type: "text".to_string(), - }, - }), - tool_choice: "auto".to_string(), - tools: vec![], - top_p: Some(1.0), - truncation: Some("disabled".to_string()), - usage: None, - user: stored_response.user.clone(), - metadata: stored_response.metadata.clone(), - }; - - if stream_requested { - return ( - StatusCode::NOT_IMPLEMENTED, - "Streaming retrieval not yet implemented", - ) - .into_response(); - } - - return ( - StatusCode::OK, - [("content-type", "application/json")], - serde_json::to_string(&openai_response).unwrap_or_else(|e| { - format!("{{\"error\": \"Failed to serialize response: {}\"}}", e) - }), - ) - .into_response(); - } - - ( - StatusCode::NOT_FOUND, - format!( - "Response with id '{}' not found in local storage", - response_id - ), - ) - .into_response() - } - - async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { - // Forward to OpenAI's cancel endpoint - let url = format!("{}/v1/responses/{}/cancel", self.base_url, response_id); - - let request_builder = self.client.post(&url); - - // Apply headers with filtering (skip content headers for POST without body) - let request_builder = if let Some(headers) = headers { - apply_request_headers(headers, request_builder, true) - } else { - request_builder - }; - - match request_builder.send().await { - Ok(response) => { - let status = response.status(); - let headers = response.headers().clone(); - - match response.text().await { - Ok(body_text) => { - let mut response = (status, body_text).into_response(); - *response.headers_mut() = preserve_response_headers(&headers); - response - } - Err(e) => ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to read response body: {}", e), - ) - .into_response(), - } - } - Err(e) => ( - StatusCode::BAD_GATEWAY, - format!("Failed to cancel response on OpenAI: {}", e), - ) - .into_response(), - } - } - - async fn route_embeddings( - &self, - _headers: Option<&HeaderMap>, - _body: &EmbeddingRequest, - _model_id: Option<&str>, - ) -> Response { - ( - StatusCode::FORBIDDEN, - "Embeddings endpoint not supported for OpenAI backend", - ) - .into_response() - } - - async fn route_rerank( - &self, - _headers: Option<&HeaderMap>, - _body: &RerankRequest, - _model_id: Option<&str>, - ) -> Response { - ( - StatusCode::FORBIDDEN, - "Rerank endpoint not supported for OpenAI backend", - ) - .into_response() - } - - async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response { - // TODO: move this spec validation to the right place - let metadata = match body.get("metadata") { - Some(Value::Object(map)) => { - if map.len() > MAX_METADATA_PROPERTIES { - return ( - StatusCode::BAD_REQUEST, - Json(json!({ - "error": { - "message": format!( - "Invalid 'metadata': too many properties. Max {}, got {}", - MAX_METADATA_PROPERTIES, map.len() - ), - "type": "invalid_request_error", - "param": "metadata", - "code": "metadata_max_properties_exceeded" - } - })), - ) - .into_response(); - } - Some(map.clone()) - } - Some(Value::Null) | None => None, - Some(other) => { - return ( - StatusCode::BAD_REQUEST, - Json(json!({ - "error": { - "message": format!( - "Invalid 'metadata': expected object or null but got {}", - other - ), - "type": "invalid_request_error", - "param": "metadata", - "code": "metadata_invalid_type" - } - })), - ) - .into_response(); - } - }; - - match self - .conversation_storage - .create_conversation(crate::data_connector::NewConversation { metadata }) - .await - { - Ok(conversation) => { - (StatusCode::OK, Json(conversation_to_json(&conversation))).into_response() - } - Err(err) => ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ - "error": { - "message": err.to_string(), - "type": "internal_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(), - } - } - - async fn get_conversation( - &self, - _headers: Option<&HeaderMap>, - conversation_id: &str, - ) -> Response { - let id: ConversationId = conversation_id.to_string().into(); - match self.conversation_storage.get_conversation(&id).await { - Ok(Some(conv)) => (StatusCode::OK, Json(conversation_to_json(&conv))).into_response(), - Ok(None) => ( - StatusCode::NOT_FOUND, - Json(json!({ - "error": { - "message": format!("Conversation with id '{}' not found.", conversation_id), - "type": "invalid_request_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(), - Err(err) => ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ - "error": { - "message": err.to_string(), - "type": "internal_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(), - } - } - - async fn update_conversation( - &self, - _headers: Option<&HeaderMap>, - conversation_id: &str, - body: &Value, - ) -> Response { - let id: ConversationId = conversation_id.to_string().into(); - let existing = match self.conversation_storage.get_conversation(&id).await { - Ok(Some(c)) => c, - Ok(None) => { - return ( - StatusCode::NOT_FOUND, - Json(json!({ - "error": { - "message": format!("Conversation with id '{}' not found.", conversation_id), - "type": "invalid_request_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(); - } - Err(err) => { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ - "error": { - "message": err.to_string(), - "type": "internal_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(); - } - }; - - // Parse metadata patch - enum Patch { - NoChange, - ClearAll, - Merge(ConversationMetadata), - } - let patch = match body.get("metadata") { - None => Patch::NoChange, - Some(Value::Null) => Patch::ClearAll, - Some(Value::Object(map)) => Patch::Merge(map.clone()), - Some(other) => { - return ( - StatusCode::BAD_REQUEST, - Json(json!({ - "error": { - "message": format!( - "Invalid 'metadata': expected object or null but got {}", - other - ), - "type": "invalid_request_error", - "param": "metadata", - "code": "metadata_invalid_type" - } - })), - ) - .into_response(); - } - }; - - let merged_metadata = match patch { - Patch::NoChange => { - return (StatusCode::OK, Json(conversation_to_json(&existing))).into_response(); - } - Patch::ClearAll => None, - Patch::Merge(upd) => { - let mut merged = existing.metadata.clone().unwrap_or_default(); - let previous = merged.len(); - for (k, v) in upd.into_iter() { - if v.is_null() { - merged.remove(&k); - } else { - merged.insert(k, v); - } - } - let updated = merged.len(); - if updated > MAX_METADATA_PROPERTIES { - return ( - StatusCode::BAD_REQUEST, - Json(json!({ - "error": { - "message": format!( - "Invalid 'metadata': too many properties after update. Max {} ({} -> {}).", - MAX_METADATA_PROPERTIES, previous, updated - ), - "type": "invalid_request_error", - "param": "metadata", - "code": "metadata_max_properties_exceeded", - "extra": { - "previous_property_count": previous, - "updated_property_count": updated - } - } - })), - ) - .into_response(); - } - if merged.is_empty() { - None - } else { - Some(merged) - } - } - }; - - match self - .conversation_storage - .update_conversation(&id, merged_metadata) - .await - { - Ok(Some(conv)) => (StatusCode::OK, Json(conversation_to_json(&conv))).into_response(), - Ok(None) => ( - StatusCode::NOT_FOUND, - Json(json!({ - "error": { - "message": format!("Conversation with id '{}' not found.", conversation_id), - "type": "invalid_request_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(), - Err(err) => ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ - "error": { - "message": err.to_string(), - "type": "internal_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(), - } - } - - async fn delete_conversation( - &self, - _headers: Option<&HeaderMap>, - conversation_id: &str, - ) -> Response { - let id: ConversationId = conversation_id.to_string().into(); - match self.conversation_storage.delete_conversation(&id).await { - Ok(true) => ( - StatusCode::OK, - Json(json!({ - "id": conversation_id, - "object": "conversation.deleted", - "deleted": true - })), - ) - .into_response(), - Ok(false) => ( - StatusCode::NOT_FOUND, - Json(json!({ - "error": { - "message": format!("Conversation with id '{}' not found.", conversation_id), - "type": "invalid_request_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(), - Err(err) => ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ - "error": { - "message": err.to_string(), - "type": "internal_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(), - } - } - - fn router_type(&self) -> &'static str { - "openai" - } - - async fn list_conversation_items( - &self, - _headers: Option<&HeaderMap>, - conversation_id: &str, - limit: Option, - order: Option, - after: Option, - ) -> Response { - let id: ConversationId = conversation_id.into(); - match self.conversation_storage.get_conversation(&id).await { - Ok(Some(_)) => {} - Ok(None) => { - return ( - StatusCode::NOT_FOUND, - Json(json!({ - "error": { - "message": format!("Conversation with id '{}' not found.", conversation_id), - "type": "invalid_request_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(); - } - Err(err) => { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ - "error": { - "message": err.to_string(), - "type": "internal_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(); - } - } - - let lim = limit.unwrap_or(20).clamp(1, 100); - let sort = match order.as_deref() { - Some("asc") => ConversationItemsSortOrder::Asc, - _ => ConversationItemsSortOrder::Desc, - }; - let params = ConversationItemsListParams { - limit: lim + 1, - order: sort, - after, - }; - - match self.conversation_item_storage.list_items(&id, params).await { - Ok(mut items) => { - let has_more = items.len() > lim; - if has_more { - items.truncate(lim); - } - let data: Vec = items - .into_iter() - .map(|it| { - json!({ - "id": it.id.0, - "type": it.item_type, - "status": it.status.unwrap_or_else(|| "completed".to_string()), - "content": it.content, - "role": it.role, - }) - }) - .collect(); - let first_id = data - .first() - .and_then(|v| v.get("id")) - .cloned() - .unwrap_or(Value::Null); - let last_id = data - .last() - .and_then(|v| v.get("id")) - .cloned() - .unwrap_or(Value::Null); - ( - StatusCode::OK, - Json(json!({ - "object": "list", - "data": data, - "first_id": first_id, - "last_id": last_id, - "has_more": has_more - })), - ) - .into_response() - } - Err(err) => ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ - "error": { - "message": err.to_string(), - "type": "internal_error", - "param": Value::Null, - "code": Value::Null - } - })), - ) - .into_response(), - } - } -} -// Maximum number of properties allowed in conversation metadata (align with server) -const MAX_METADATA_PROPERTIES: usize = 16; - -fn conversation_to_json(conversation: &Conversation) -> Value { - json!({ - "id": conversation.id.0, - "object": "conversation", - "created_at": conversation.created_at.timestamp(), - "metadata": to_value(&conversation.metadata).unwrap_or(Value::Null), - }) -} - -async fn persist_items_with_storages( - conv_storage: SharedConversationStorage, - item_storage: SharedConversationItemStorage, - conversation_id: String, - request: ResponsesRequest, - response: Value, -) -> Result<(), String> { - let conv_id: ConversationId = conversation_id.as_str().into(); - match conv_storage.get_conversation(&conv_id).await { - Ok(Some(_)) => {} - Ok(None) => { - warn!(conversation_id = %conv_id.0, "Conversation not found; skipping item persistence"); - return Ok(()); - } - Err(err) => return Err(err.to_string()), - } - - // Extract response_id once for attaching to both input and output items - let response_id_opt = response - .get("id") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - // Helper to ensure status defaults to completed - async fn create_and_link_item( - item_storage: &SharedConversationItemStorage, - conv_id: &ConversationId, - mut new_item: DCNewConversationItem, - ) -> Result<(), String> { - if new_item.status.is_none() { - new_item.status = Some("completed".to_string()); - } - let created = item_storage - .create_item(new_item) - .await - .map_err(|e| e.to_string())?; - item_storage - .link_item(conv_id, &created.id, chrono::Utc::now()) - .await - .map_err(|e| e.to_string())?; - tracing::info!(conversation_id = %conv_id.0, item_id = %created.id.0, item_type = %created.item_type, "Persisted conversation item and link"); - Ok(()) - } - - match request.input.clone() { - ResponseInput::Text(text) => { - let new_item = DCNewConversationItem { - id: None, // generate new message id for input - response_id: response_id_opt.clone(), - item_type: "message".to_string(), - role: Some("user".to_string()), - content: json!([{ "type": "input_text", "text": text }]), - status: Some("completed".to_string()), - }; - create_and_link_item(&item_storage, &conv_id, new_item).await?; - } - ResponseInput::Items(items) => { - for input_item in items { - match input_item { - ResponseInputOutputItem::Message { - role, - content, - status, - .. - } => { - let content_v = - serde_json::to_value(&content).map_err(|e| e.to_string())?; - let new_item = DCNewConversationItem { - id: None, // generate new id for input items - response_id: response_id_opt.clone(), - item_type: "message".to_string(), - role: Some(role), - content: content_v, - status, - }; - create_and_link_item(&item_storage, &conv_id, new_item).await?; - } - ResponseInputOutputItem::Reasoning { - summary, - content, - status, - .. - } => { - let new_item = DCNewConversationItem { - id: None, // generate new id for input items - response_id: response_id_opt.clone(), - item_type: "reasoning".to_string(), - role: None, - content: json!({ "summary": summary, "content": content }), - status, - }; - create_and_link_item(&item_storage, &conv_id, new_item).await?; - } - ResponseInputOutputItem::FunctionToolCall { - name, - arguments, - output, - status, - .. - } => { - let new_item = DCNewConversationItem { - id: None, // generate new id for input items - response_id: response_id_opt.clone(), - item_type: "function_tool_call".to_string(), - role: None, - content: json!({ "name": name, "arguments": arguments, "output": output }), - status, - }; - create_and_link_item(&item_storage, &conv_id, new_item).await?; - } - } - } - } - } - - if let Some(output_array) = response.get("output").and_then(|v| v.as_array()) { - for item in output_array { - let item_type = match item.get("type").and_then(|v| v.as_str()) { - Some(t) => t, - None => continue, - }; - - match item_type { - "message" => { - let id_in = item - .get("id") - .and_then(|v| v.as_str()) - .map(|s| crate::data_connector::ConversationItemId(s.to_string())); - let role = item - .get("role") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let content_v = item - .get("content") - .cloned() - .unwrap_or_else(|| Value::Array(Vec::new())); - let status = item - .get("status") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let new_item = DCNewConversationItem { - id: id_in, - response_id: response_id_opt.clone(), - item_type: "message".to_string(), - role, - content: content_v, - status, - }; - create_and_link_item(&item_storage, &conv_id, new_item).await?; - } - "reasoning" => { - let id_in = item - .get("id") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let summary_v = item - .get("summary") - .cloned() - .unwrap_or_else(|| Value::Array(Vec::new())); - let content_v = item - .get("content") - .cloned() - .unwrap_or_else(|| Value::Array(Vec::new())); - let status = item - .get("status") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let new_item = DCNewConversationItem { - id: id_in.map(crate::data_connector::ConversationItemId), - response_id: response_id_opt.clone(), - item_type: "reasoning".to_string(), - role: None, - content: json!({ "summary": summary_v, "content": content_v }), - status, - }; - create_and_link_item(&item_storage, &conv_id, new_item).await?; - } - "function_tool_call" => { - let id_in = item - .get("id") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); - let arguments = item.get("arguments").and_then(|v| v.as_str()).unwrap_or(""); - let output_str = item.get("output").and_then(|v| v.as_str()).unwrap_or(""); - let status = item - .get("status") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let new_item = DCNewConversationItem { - id: id_in.map(crate::data_connector::ConversationItemId), - response_id: response_id_opt.clone(), - item_type: "function_tool_call".to_string(), - role: None, - content: json!({ - "name": name, - "arguments": arguments, - "output": output_str - }), - status, - }; - create_and_link_item(&item_storage, &conv_id, new_item).await?; - } - "mcp_call" => { - let id_in = item - .get("id") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); - let arguments = item.get("arguments").and_then(|v| v.as_str()).unwrap_or(""); - let output_str = item.get("output").and_then(|v| v.as_str()).unwrap_or(""); - let status = item - .get("status") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let content_v = json!({ - "server_label": item.get("server_label").cloned().unwrap_or(Value::Null), - "name": name, - "arguments": arguments, - "output": output_str, - "error": item.get("error").cloned().unwrap_or(Value::Null), - "approval_request_id": item.get("approval_request_id").cloned().unwrap_or(Value::Null) - }); - let new_item = DCNewConversationItem { - id: id_in.map(crate::data_connector::ConversationItemId), - response_id: response_id_opt.clone(), - item_type: "mcp_call".to_string(), - role: None, - content: content_v, - status, - }; - create_and_link_item(&item_storage, &conv_id, new_item).await?; - } - "mcp_list_tools" => { - let id_in = item - .get("id") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let content_v = json!({ - "server_label": item.get("server_label").cloned().unwrap_or(Value::Null), - "tools": item.get("tools").cloned().unwrap_or_else(|| Value::Array(Vec::new())) - }); - let new_item = DCNewConversationItem { - id: id_in.map(crate::data_connector::ConversationItemId), - response_id: response_id_opt.clone(), - item_type: "mcp_list_tools".to_string(), - role: None, - content: content_v, - status: Some("completed".to_string()), - }; - create_and_link_item(&item_storage, &conv_id, new_item).await?; - } - _ => {} - } - } - } - - Ok(()) -} diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index a74503424..29d4a6c7f 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -19,12 +19,13 @@ pub mod factory; pub mod grpc; pub mod header_utils; pub mod http; +pub mod openai; // New refactored OpenAI router module pub mod router_manager; pub use factory::RouterFactory; -// Re-export HTTP routers for convenience (keeps routers::openai_router path working) -pub use http::{openai_router, pd_router, pd_types, router}; +// Re-export HTTP routers for convenience +pub use http::{pd_router, pd_types, router}; /// Core trait for all router implementations /// diff --git a/sgl-router/src/routers/openai/conversations.rs b/sgl-router/src/routers/openai/conversations.rs new file mode 100644 index 000000000..e98ad11ef --- /dev/null +++ b/sgl-router/src/routers/openai/conversations.rs @@ -0,0 +1,574 @@ +//! Conversation CRUD operations and persistence + +use crate::data_connector::{ + conversation_items::ListParams, conversation_items::SortOrder, Conversation, ConversationId, + ConversationItemStorage, ConversationStorage, NewConversation, NewConversationItem, ResponseId, + ResponseStorage, SharedConversationItemStorage, SharedConversationStorage, +}; +use crate::protocols::spec::{ResponseInput, ResponsesRequest}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::Json; +use chrono::Utc; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::sync::Arc; +use tracing::{info, warn}; + +use super::responses::build_stored_response; + +/// Maximum number of properties allowed in conversation metadata +pub(crate) const MAX_METADATA_PROPERTIES: usize = 16; + +// ============================================================================ +// Conversation CRUD Operations +// ============================================================================ + +/// Create a new conversation +pub(super) async fn create_conversation( + conversation_storage: &SharedConversationStorage, + body: Value, +) -> Response { + // TODO: The validation should be done in the right place + let metadata = match body.get("metadata") { + Some(Value::Object(map)) => { + if map.len() > MAX_METADATA_PROPERTIES { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": format!( + "metadata cannot have more than {} properties", + MAX_METADATA_PROPERTIES + ) + })), + ) + .into_response(); + } + Some(map.clone()) + } + Some(_) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "metadata must be an object"})), + ) + .into_response(); + } + None => None, + }; + + let new_conv = NewConversation { metadata }; + + match conversation_storage.create_conversation(new_conv).await { + Ok(conversation) => { + info!(conversation_id = %conversation.id.0, "Created conversation"); + (StatusCode::OK, Json(conversation_to_json(&conversation))).into_response() + } + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": format!("Failed to create conversation: {}", e)})), + ) + .into_response(), + } +} + +/// Get a conversation by ID +pub(super) async fn get_conversation( + conversation_storage: &SharedConversationStorage, + conv_id: &str, +) -> Response { + let conversation_id = ConversationId::from(conv_id); + + match conversation_storage + .get_conversation(&conversation_id) + .await + { + Ok(Some(conversation)) => { + (StatusCode::OK, Json(conversation_to_json(&conversation))).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(json!({"error": "Conversation not found"})), + ) + .into_response(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": format!("Failed to get conversation: {}", e)})), + ) + .into_response(), + } +} + +/// Update a conversation's metadata +pub(super) async fn update_conversation( + conversation_storage: &SharedConversationStorage, + conv_id: &str, + body: Value, +) -> Response { + let conversation_id = ConversationId::from(conv_id); + + let current_meta = match conversation_storage + .get_conversation(&conversation_id) + .await + { + Ok(Some(meta)) => meta, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "Conversation not found"})), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": format!("Failed to get conversation: {}", e)})), + ) + .into_response(); + } + }; + + #[derive(Debug)] + enum Patch { + Set(String, Value), + Delete(String), + } + + let mut patches: Vec = Vec::new(); + + if let Some(metadata_val) = body.get("metadata") { + if let Some(map) = metadata_val.as_object() { + for (k, v) in map { + if v.is_null() { + patches.push(Patch::Delete(k.clone())); + } else { + patches.push(Patch::Set(k.clone(), v.clone())); + } + } + } else { + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "metadata must be an object"})), + ) + .into_response(); + } + } + + let mut new_metadata = current_meta.metadata.clone().unwrap_or_default(); + for patch in patches { + match patch { + Patch::Set(k, v) => { + new_metadata.insert(k, v); + } + Patch::Delete(k) => { + new_metadata.remove(&k); + } + } + } + + if new_metadata.len() > MAX_METADATA_PROPERTIES { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": format!( + "metadata cannot have more than {} properties", + MAX_METADATA_PROPERTIES + ) + })), + ) + .into_response(); + } + + let final_metadata = if new_metadata.is_empty() { + None + } else { + Some(new_metadata) + }; + + match conversation_storage + .update_conversation(&conversation_id, final_metadata) + .await + { + Ok(Some(conversation)) => { + info!(conversation_id = %conversation_id.0, "Updated conversation"); + (StatusCode::OK, Json(conversation_to_json(&conversation))).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(json!({"error": "Conversation not found"})), + ) + .into_response(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": format!("Failed to update conversation: {}", e)})), + ) + .into_response(), + } +} + +/// Delete a conversation +pub(super) async fn delete_conversation( + conversation_storage: &SharedConversationStorage, + conv_id: &str, +) -> Response { + let conversation_id = ConversationId::from(conv_id); + + match conversation_storage + .get_conversation(&conversation_id) + .await + { + Ok(Some(_)) => {} + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "Conversation not found"})), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": format!("Failed to get conversation: {}", e)})), + ) + .into_response(); + } + } + + match conversation_storage + .delete_conversation(&conversation_id) + .await + { + Ok(_) => { + info!(conversation_id = %conversation_id.0, "Deleted conversation"); + ( + StatusCode::OK, + Json(json!({ + "id": conversation_id.0, + "object": "conversation.deleted", + "deleted": true + })), + ) + .into_response() + } + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": format!("Failed to delete conversation: {}", e)})), + ) + .into_response(), + } +} + +/// List items in a conversation with pagination +pub(super) async fn list_conversation_items( + conversation_storage: &SharedConversationStorage, + item_storage: &SharedConversationItemStorage, + conv_id: &str, + query_params: HashMap, +) -> Response { + let conversation_id = ConversationId::from(conv_id); + + match conversation_storage + .get_conversation(&conversation_id) + .await + { + Ok(Some(_)) => {} + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "Conversation not found"})), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": format!("Failed to get conversation: {}", e)})), + ) + .into_response(); + } + } + + let limit: usize = query_params + .get("limit") + .and_then(|s| s.parse().ok()) + .unwrap_or(100); + + let after = query_params.get("after").map(|s| s.to_string()); + + // Default to descending order (most recent first) + let order = query_params + .get("order") + .and_then(|s| match s.as_str() { + "asc" => Some(SortOrder::Asc), + "desc" => Some(SortOrder::Desc), + _ => None, + }) + .unwrap_or(SortOrder::Desc); + + let params = ListParams { + limit, + order, + after, + }; + + match item_storage.list_items(&conversation_id, params).await { + Ok(items) => { + let item_values: Vec = items + .iter() + .map(|item| { + let mut obj = serde_json::Map::new(); + obj.insert("id".to_string(), json!(item.id.0)); + obj.insert("type".to_string(), json!(item.item_type)); + obj.insert("created_at".to_string(), json!(item.created_at)); + + obj.insert("content".to_string(), item.content.clone()); + if let Some(status) = &item.status { + obj.insert("status".to_string(), json!(status)); + } + + Value::Object(obj) + }) + .collect(); + + let has_more = items.len() == limit; + let last_id = items.last().map(|item| item.id.0.clone()); + + ( + StatusCode::OK, + Json(json!({ + "object": "list", + "data": item_values, + "has_more": has_more, + "first_id": items.first().map(|item| &item.id.0), + "last_id": last_id, + })), + ) + .into_response() + } + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": format!("Failed to list items: {}", e)})), + ) + .into_response(), + } +} + +// ============================================================================ +// Persistence Operations +// ============================================================================ + +/// Persist conversation items (delegates to persist_items_with_storages) +pub(super) async fn persist_conversation_items( + conversation_storage: Arc, + item_storage: Arc, + response_storage: Arc, + response_json: &Value, + original_body: &ResponsesRequest, +) -> Result<(), String> { + persist_items_with_storages( + conversation_storage, + item_storage, + response_storage, + response_json, + original_body, + ) + .await +} + +/// Helper function to create and link a conversation item (two-step API) +async fn create_and_link_item( + item_storage: &Arc, + conv_id: &ConversationId, + mut new_item: NewConversationItem, +) -> Result<(), String> { + // Set default status if not provided + if new_item.status.is_none() { + new_item.status = Some("completed".to_string()); + } + + // Step 1: Create the item + let created = item_storage + .create_item(new_item) + .await + .map_err(|e| format!("Failed to create item: {}", e))?; + + // Step 2: Link it to the conversation + item_storage + .link_item(conv_id, &created.id, Utc::now()) + .await + .map_err(|e| format!("Failed to link item: {}", e))?; + + info!( + conversation_id = %conv_id.0, + item_id = %created.id.0, + item_type = %created.item_type, + "Persisted conversation item and link" + ); + + Ok(()) +} + +/// Persist conversation items with all storages +async fn persist_items_with_storages( + conversation_storage: Arc, + item_storage: Arc, + response_storage: Arc, + response_json: &Value, + original_body: &ResponsesRequest, +) -> Result<(), String> { + let conv_id = match &original_body.conversation { + Some(id) => ConversationId::from(id.as_str()), + None => return Ok(()), + }; + + if conversation_storage + .get_conversation(&conv_id) + .await + .map_err(|e| format!("Failed to get conversation: {}", e))? + .is_none() + { + warn!(conversation_id = %conv_id.0, "Conversation not found, skipping item persistence"); + return Ok(()); + } + + let response_id_str = response_json + .get("id") + .and_then(|v| v.as_str()) + .ok_or_else(|| "Response missing id field".to_string())?; + let response_id = ResponseId::from(response_id_str); + + let response_id_opt = Some(response_id_str.to_string()); + + // Persist input items + match &original_body.input { + ResponseInput::Text(text) => { + let new_item = NewConversationItem { + id: None, // Let storage generate ID + response_id: response_id_opt.clone(), + item_type: "message".to_string(), + role: Some("user".to_string()), + content: json!([{ "type": "input_text", "text": text }]), + status: Some("completed".to_string()), + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + ResponseInput::Items(items_array) => { + for input_item in items_array { + match input_item { + crate::protocols::spec::ResponseInputOutputItem::Message { + role, + content, + status, + .. + } => { + let content_v = serde_json::to_value(content) + .map_err(|e| format!("Failed to serialize content: {}", e))?; + let new_item = NewConversationItem { + id: None, + response_id: response_id_opt.clone(), + item_type: "message".to_string(), + role: Some(role.clone()), + content: content_v, + status: status.clone(), + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + _ => { + // For other types (FunctionToolCall, etc.), serialize the whole item + let item_val = serde_json::to_value(input_item) + .map_err(|e| format!("Failed to serialize item: {}", e))?; + let new_item = NewConversationItem { + id: None, + response_id: response_id_opt.clone(), + item_type: "unknown".to_string(), + role: None, + content: item_val, + status: Some("completed".to_string()), + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + } + } + } + } + + // Persist output items + if let Some(output_arr) = response_json.get("output").and_then(|v| v.as_array()) { + for output_item in output_arr { + if let Some(obj) = output_item.as_object() { + let item_type = obj + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or("message"); + + let role = obj.get("role").and_then(|v| v.as_str()).map(String::from); + let status = obj.get("status").and_then(|v| v.as_str()).map(String::from); + + let content = if item_type == "message" { + obj.get("content").cloned().unwrap_or(json!([])) + } else if item_type == "function_call" || item_type == "function_tool_call" { + json!({ + "type": "function_call", + "name": obj.get("name"), + "call_id": obj.get("call_id").or_else(|| obj.get("id")), + "arguments": obj.get("arguments") + }) + } else if item_type == "function_call_output" { + json!({ + "type": "function_call_output", + "call_id": obj.get("call_id"), + "output": obj.get("output") + }) + } else { + output_item.clone() + }; + + let new_item = NewConversationItem { + id: None, + response_id: response_id_opt.clone(), + item_type: item_type.to_string(), + role, + content, + status, + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + } + } + + // Store the full response using the shared helper + let mut stored_response = build_stored_response(response_json, original_body); + stored_response.id = response_id; + let final_response_id = stored_response.id.clone(); + + response_storage + .store_response(stored_response) + .await + .map_err(|e| format!("Failed to store response in conversation: {}", e))?; + + info!(conversation_id = %conv_id.0, response_id = %final_response_id.0, "Persisted conversation items and response"); + + Ok(()) +} + +// ============================================================================ +// Helper Functions +// ============================================================================ + +/// Convert conversation to JSON response +fn conversation_to_json(conversation: &Conversation) -> Value { + let mut response = json!({ + "id": conversation.id.0, + "object": "conversation", + "created_at": conversation.created_at.timestamp() + }); + + if let Some(metadata) = &conversation.metadata { + if !metadata.is_empty() { + if let Some(obj) = response.as_object_mut() { + obj.insert("metadata".to_string(), Value::Object(metadata.clone())); + } + } + } + + response +} diff --git a/sgl-router/src/routers/openai/mcp.rs b/sgl-router/src/routers/openai/mcp.rs new file mode 100644 index 000000000..3d3cdcf55 --- /dev/null +++ b/sgl-router/src/routers/openai/mcp.rs @@ -0,0 +1,967 @@ +//! MCP (Model Context Protocol) Integration Module +//! +//! This module contains all MCP-related functionality for the OpenAI router: +//! - Tool loop state management for multi-turn tool calling +//! - MCP tool execution and result handling +//! - Output item builders for MCP-specific response formats +//! - SSE event generation for streaming MCP operations +//! - Payload transformation for MCP tool interception +//! - Metadata injection for MCP operations + +use crate::mcp::McpClientManager; +use crate::protocols::spec::{ResponseInput, ResponseToolType, ResponsesRequest}; +use crate::routers::header_utils::apply_request_headers; +use axum::http::HeaderMap; +use bytes::Bytes; +use serde_json::{json, to_value, Value}; +use std::{io, sync::Arc}; +use tokio::sync::mpsc; +use tracing::{info, warn}; + +use super::utils::event_types; + +// ============================================================================ +// Configuration and State Types +// ============================================================================ + +/// Configuration for MCP tool calling loops +#[allow(dead_code)] +#[derive(Debug, Clone)] +pub(crate) struct McpLoopConfig { + /// Maximum iterations as safety limit (internal only, default: 10) + /// Prevents infinite loops when max_tool_calls is not set + pub max_iterations: usize, +} + +impl Default for McpLoopConfig { + fn default() -> Self { + Self { max_iterations: 10 } + } +} + +/// State for tracking multi-turn tool calling loop +pub(crate) struct ToolLoopState { + /// Current iteration number (starts at 0, increments with each tool call) + pub iteration: usize, + /// Total number of tool calls executed + pub total_calls: usize, + /// Conversation history (function_call and function_call_output items) + pub conversation_history: Vec, + /// Original user input (preserved for building resume payloads) + pub original_input: ResponseInput, +} + +impl ToolLoopState { + pub fn new(original_input: ResponseInput) -> Self { + Self { + iteration: 0, + total_calls: 0, + conversation_history: Vec::new(), + original_input, + } + } + + /// Record a tool call in the loop state + pub fn record_call( + &mut self, + call_id: String, + tool_name: String, + args_json_str: String, + output_str: String, + ) { + // Add function_call item to history + let func_item = json!({ + "type": event_types::ITEM_TYPE_FUNCTION_CALL, + "call_id": call_id, + "name": tool_name, + "arguments": args_json_str + }); + self.conversation_history.push(func_item); + + // Add function_call_output item to history + let output_item = json!({ + "type": "function_call_output", + "call_id": call_id, + "output": output_str + }); + self.conversation_history.push(output_item); + } +} + +/// Represents a function call being accumulated across delta events +#[derive(Debug, Clone)] +pub(crate) struct FunctionCallInProgress { + pub call_id: String, + pub name: String, + pub arguments_buffer: String, + pub output_index: usize, + pub last_obfuscation: Option, + pub assigned_output_index: Option, +} + +impl FunctionCallInProgress { + pub fn new(call_id: String, output_index: usize) -> Self { + Self { + call_id, + name: String::new(), + arguments_buffer: String::new(), + output_index, + last_obfuscation: None, + assigned_output_index: None, + } + } + + pub fn is_complete(&self) -> bool { + // A tool call is complete if it has a name + !self.name.is_empty() + } + + pub fn effective_output_index(&self) -> usize { + self.assigned_output_index.unwrap_or(self.output_index) + } +} + +// ============================================================================ +// MCP Manager Integration +// ============================================================================ + +/// Build a request-scoped MCP manager from request tools, if present. +pub(super) async fn mcp_manager_from_request_tools( + tools: &[crate::protocols::spec::ResponseTool], +) -> Option> { + let tool = tools + .iter() + .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?; + let server_url = tool.server_url.as_ref()?.trim().to_string(); + if !(server_url.starts_with("http://") || server_url.starts_with("https://")) { + warn!( + "Ignoring MCP server_url with unsupported scheme: {}", + server_url + ); + return None; + } + let name = tool + .server_label + .clone() + .unwrap_or_else(|| "request-mcp".to_string()); + let token = tool.authorization.clone(); + let transport = if server_url.contains("/sse") { + crate::mcp::McpTransport::Sse { + url: server_url, + token, + } + } else { + crate::mcp::McpTransport::Streamable { + url: server_url, + token, + } + }; + let cfg = crate::mcp::McpConfig { + servers: vec![crate::mcp::McpServerConfig { name, transport }], + }; + match McpClientManager::new(cfg).await { + Ok(mgr) => Some(Arc::new(mgr)), + Err(err) => { + warn!("Failed to initialize request-scoped MCP manager: {}", err); + None + } + } +} + +// ============================================================================ +// Tool Execution +// ============================================================================ + +/// Execute an MCP tool call +pub(super) async fn execute_mcp_call( + mcp_mgr: &Arc, + tool_name: &str, + args_json_str: &str, +) -> Result<(String, String), String> { + let args_value: Value = + serde_json::from_str(args_json_str).map_err(|e| format!("parse tool args: {}", e))?; + let args_obj = args_value.as_object().cloned(); + + let server_name = mcp_mgr + .get_tool(tool_name) + .map(|t| t.server) + .ok_or_else(|| format!("tool not found: {}", tool_name))?; + + let result = mcp_mgr + .call_tool(tool_name, args_obj) + .await + .map_err(|e| format!("tool call failed: {}", e))?; + + let output_str = serde_json::to_string(&result) + .map_err(|e| format!("Failed to serialize tool result: {}", e))?; + Ok((server_name, output_str)) +} + +/// Execute detected tool calls and send completion events to client +/// Returns false if client disconnected during execution +pub(super) async fn execute_streaming_tool_calls( + pending_calls: Vec, + active_mcp: &Arc, + tx: &mpsc::UnboundedSender>, + state: &mut ToolLoopState, + server_label: &str, + sequence_number: &mut u64, +) -> bool { + // Execute all pending tool calls (sequential, as PR3 is skipped) + for call in pending_calls { + // Skip if name is empty (invalid call) + if call.name.is_empty() { + warn!( + "Skipping incomplete tool call: name is empty, args_len={}", + call.arguments_buffer.len() + ); + continue; + } + + info!( + "Executing tool call during streaming: {} ({})", + call.name, call.call_id + ); + + // Use empty JSON object if arguments_buffer is empty + let args_str = if call.arguments_buffer.is_empty() { + "{}" + } else { + &call.arguments_buffer + }; + + let call_result = execute_mcp_call(active_mcp, &call.name, args_str).await; + let (output_str, success, error_msg) = match call_result { + Ok((_, output)) => (output, true, None), + Err(err) => { + warn!("Tool execution failed during streaming: {}", err); + (json!({ "error": &err }).to_string(), false, Some(err)) + } + }; + + // Send mcp_call completion event to client + if !send_mcp_call_completion_events_with_error( + tx, + &call, + &output_str, + server_label, + success, + error_msg.as_deref(), + sequence_number, + ) { + // Client disconnected, no point continuing tool execution + return false; + } + + // Record the call + state.record_call(call.call_id, call.name, call.arguments_buffer, output_str); + } + true +} + +// ============================================================================ +// Payload Transformation +// ============================================================================ + +/// Transform payload to replace MCP tools with function tools for streaming +pub(super) fn prepare_mcp_payload_for_streaming( + payload: &mut Value, + active_mcp: &Arc, +) { + if let Some(obj) = payload.as_object_mut() { + // Remove any non-function tools from outgoing payload + if let Some(v) = obj.get_mut("tools") { + if let Some(arr) = v.as_array_mut() { + arr.retain(|item| { + item.get("type") + .and_then(|v| v.as_str()) + .map(|s| s == event_types::ITEM_TYPE_FUNCTION) + .unwrap_or(false) + }); + } + } + + // Build function tools for all discovered MCP tools + let mut tools_json = Vec::new(); + let tools = active_mcp.list_tools(); + for t in tools { + let parameters = t.parameters.clone().unwrap_or(serde_json::json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + })); + let tool = serde_json::json!({ + "type": event_types::ITEM_TYPE_FUNCTION, + "name": t.name, + "description": t.description, + "parameters": parameters + }); + tools_json.push(tool); + } + if !tools_json.is_empty() { + obj.insert("tools".to_string(), Value::Array(tools_json)); + obj.insert("tool_choice".to_string(), Value::String("auto".to_string())); + } + } +} + +/// Build a resume payload with conversation history +pub(super) fn build_resume_payload( + base_payload: &Value, + conversation_history: &[Value], + original_input: &ResponseInput, + tools_json: &Value, + is_streaming: bool, +) -> Result { + // Clone the base payload which already has cleaned fields + let mut payload = base_payload.clone(); + + let obj = payload + .as_object_mut() + .ok_or_else(|| "payload not an object".to_string())?; + + // Build input array: start with original user input + let mut input_array = Vec::new(); + + // Add original user message + // For structured input, serialize the original input items + match original_input { + ResponseInput::Text(text) => { + let user_item = json!({ + "type": "message", + "role": "user", + "content": [{ "type": "input_text", "text": text }] + }); + input_array.push(user_item); + } + ResponseInput::Items(items) => { + // Items are already structured ResponseInputOutputItem, convert to JSON + if let Ok(items_value) = to_value(items) { + if let Some(items_arr) = items_value.as_array() { + input_array.extend_from_slice(items_arr); + } + } + } + } + + // Add all conversation history (function calls and outputs) + input_array.extend_from_slice(conversation_history); + + obj.insert("input".to_string(), Value::Array(input_array)); + + // Use the transformed tools (function tools, not MCP tools) + if let Some(tools_arr) = tools_json.as_array() { + if !tools_arr.is_empty() { + obj.insert("tools".to_string(), tools_json.clone()); + } + } + + // Set streaming mode based on caller's context + obj.insert("stream".to_string(), Value::Bool(is_streaming)); + obj.insert("store".to_string(), Value::Bool(false)); + + // Note: SGLang-specific fields were already removed from base_payload + // before it was passed to execute_tool_loop (see route_responses lines 1935-1946) + + Ok(payload) +} + +// ============================================================================ +// SSE Event Senders +// ============================================================================ + +/// Send mcp_list_tools events to client at the start of streaming +/// Returns false if client disconnected +pub(super) fn send_mcp_list_tools_events( + tx: &mpsc::UnboundedSender>, + mcp: &Arc, + server_label: &str, + output_index: usize, + sequence_number: &mut u64, +) -> bool { + let tools_item_full = build_mcp_list_tools_item(mcp, server_label); + let item_id = tools_item_full + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + // Create empty tools version for the initial added event + let mut tools_item_empty = tools_item_full.clone(); + if let Some(obj) = tools_item_empty.as_object_mut() { + obj.insert("tools".to_string(), json!([])); + } + + // Event 1: response.output_item.added with empty tools + let event1_payload = json!({ + "type": event_types::OUTPUT_ITEM_ADDED, + "sequence_number": *sequence_number, + "output_index": output_index, + "item": tools_item_empty + }); + *sequence_number += 1; + let event1 = format!( + "event: {}\ndata: {}\n\n", + event_types::OUTPUT_ITEM_ADDED, + event1_payload + ); + if tx.send(Ok(Bytes::from(event1))).is_err() { + return false; // Client disconnected + } + + // Event 2: response.mcp_list_tools.in_progress + let event2_payload = json!({ + "type": event_types::MCP_LIST_TOOLS_IN_PROGRESS, + "sequence_number": *sequence_number, + "output_index": output_index, + "item_id": item_id + }); + *sequence_number += 1; + let event2 = format!( + "event: {}\ndata: {}\n\n", + event_types::MCP_LIST_TOOLS_IN_PROGRESS, + event2_payload + ); + if tx.send(Ok(Bytes::from(event2))).is_err() { + return false; + } + + // Event 3: response.mcp_list_tools.completed + let event3_payload = json!({ + "type": event_types::MCP_LIST_TOOLS_COMPLETED, + "sequence_number": *sequence_number, + "output_index": output_index, + "item_id": item_id + }); + *sequence_number += 1; + let event3 = format!( + "event: {}\ndata: {}\n\n", + event_types::MCP_LIST_TOOLS_COMPLETED, + event3_payload + ); + if tx.send(Ok(Bytes::from(event3))).is_err() { + return false; + } + + // Event 4: response.output_item.done with full tools list + let event4_payload = json!({ + "type": event_types::OUTPUT_ITEM_DONE, + "sequence_number": *sequence_number, + "output_index": output_index, + "item": tools_item_full + }); + *sequence_number += 1; + let event4 = format!( + "event: {}\ndata: {}\n\n", + event_types::OUTPUT_ITEM_DONE, + event4_payload + ); + tx.send(Ok(Bytes::from(event4))).is_ok() +} + +/// Send mcp_call completion events after tool execution +/// Returns false if client disconnected +pub(super) fn send_mcp_call_completion_events_with_error( + tx: &mpsc::UnboundedSender>, + call: &FunctionCallInProgress, + output: &str, + server_label: &str, + success: bool, + error_msg: Option<&str>, + sequence_number: &mut u64, +) -> bool { + let effective_output_index = call.effective_output_index(); + + // Build mcp_call item (reuse existing function) + let mcp_call_item = build_mcp_call_item( + &call.name, + &call.arguments_buffer, + output, + server_label, + success, + error_msg, + ); + + // Get the mcp_call item_id + let item_id = mcp_call_item + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + // Event 1: response.mcp_call.completed + let completed_payload = json!({ + "type": event_types::MCP_CALL_COMPLETED, + "sequence_number": *sequence_number, + "output_index": effective_output_index, + "item_id": item_id + }); + *sequence_number += 1; + + let completed_event = format!( + "event: {}\ndata: {}\n\n", + event_types::MCP_CALL_COMPLETED, + completed_payload + ); + if tx.send(Ok(Bytes::from(completed_event))).is_err() { + return false; + } + + // Event 2: response.output_item.done (with completed mcp_call) + let done_payload = json!({ + "type": event_types::OUTPUT_ITEM_DONE, + "sequence_number": *sequence_number, + "output_index": effective_output_index, + "item": mcp_call_item + }); + *sequence_number += 1; + + let done_event = format!( + "event: {}\ndata: {}\n\n", + event_types::OUTPUT_ITEM_DONE, + done_payload + ); + tx.send(Ok(Bytes::from(done_event))).is_ok() +} + +// ============================================================================ +// Metadata Injection +// ============================================================================ + +/// Inject MCP metadata into a streaming response +pub(super) fn inject_mcp_metadata_streaming( + response: &mut Value, + state: &ToolLoopState, + mcp: &Arc, + server_label: &str, +) { + if let Some(output_array) = response.get_mut("output").and_then(|v| v.as_array_mut()) { + output_array.retain(|item| { + item.get("type").and_then(|t| t.as_str()) != Some(event_types::ITEM_TYPE_MCP_LIST_TOOLS) + }); + + let list_tools_item = build_mcp_list_tools_item(mcp, server_label); + output_array.insert(0, list_tools_item); + + let mcp_call_items = + build_executed_mcp_call_items(&state.conversation_history, server_label); + let mut insert_pos = 1; + for item in mcp_call_items { + output_array.insert(insert_pos, item); + insert_pos += 1; + } + } else if let Some(obj) = response.as_object_mut() { + let mut output_items = Vec::new(); + output_items.push(build_mcp_list_tools_item(mcp, server_label)); + output_items.extend(build_executed_mcp_call_items( + &state.conversation_history, + server_label, + )); + obj.insert("output".to_string(), Value::Array(output_items)); + } +} + +// ============================================================================ +// Tool Loop Execution +// ============================================================================ + +/// Execute the tool calling loop +pub(super) async fn execute_tool_loop( + client: &reqwest::Client, + url: &str, + headers: Option<&HeaderMap>, + initial_payload: Value, + original_body: &ResponsesRequest, + active_mcp: &Arc, + config: &McpLoopConfig, +) -> Result { + let mut state = ToolLoopState::new(original_body.input.clone()); + + // Get max_tool_calls from request (None means no user-specified limit) + let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize); + + // Keep initial_payload as base template (already has fields cleaned) + let base_payload = initial_payload.clone(); + let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([])); + let mut current_payload = initial_payload; + + info!( + "Starting tool loop: max_tool_calls={:?}, max_iterations={}", + max_tool_calls, config.max_iterations + ); + + loop { + // Make request to upstream + let request_builder = client.post(url).json(¤t_payload); + let request_builder = if let Some(headers) = headers { + apply_request_headers(headers, request_builder, true) + } else { + request_builder + }; + + let response = request_builder + .send() + .await + .map_err(|e| format!("upstream request failed: {}", e))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(format!("upstream error {}: {}", status, body)); + } + + let mut response_json = response + .json::() + .await + .map_err(|e| format!("parse response: {}", e))?; + + // Check for function call + if let Some((call_id, tool_name, args_json_str)) = extract_function_call(&response_json) { + state.iteration += 1; + state.total_calls += 1; + + info!( + "Tool loop iteration {}: calling {} (call_id: {})", + state.iteration, tool_name, call_id + ); + + // Check combined limit: use minimum of user's max_tool_calls (if set) and safety max_iterations + let effective_limit = match max_tool_calls { + Some(user_max) => user_max.min(config.max_iterations), + None => config.max_iterations, + }; + + if state.total_calls > effective_limit { + if let Some(user_max) = max_tool_calls { + if state.total_calls > user_max { + warn!("Reached user-specified max_tool_calls limit: {}", user_max); + } else { + warn!( + "Reached safety max_iterations limit: {}", + config.max_iterations + ); + } + } else { + warn!( + "Reached safety max_iterations limit: {}", + config.max_iterations + ); + } + + return build_incomplete_response( + response_json, + state, + "max_tool_calls", + active_mcp, + original_body, + ); + } + + // Execute tool + let call_result = execute_mcp_call(active_mcp, &tool_name, &args_json_str).await; + + let output_str = match call_result { + Ok((_, output)) => output, + Err(err) => { + warn!("Tool execution failed: {}", err); + // Return error as output, let model decide how to proceed + json!({ "error": err }).to_string() + } + }; + + // Record the call + state.record_call(call_id, tool_name, args_json_str, output_str); + + // Build resume payload + current_payload = build_resume_payload( + &base_payload, + &state.conversation_history, + &state.original_input, + &tools_json, + false, // is_streaming = false (non-streaming tool loop) + )?; + } else { + // No more tool calls, we're done + info!( + "Tool loop completed: {} iterations, {} total calls", + state.iteration, state.total_calls + ); + + // Inject MCP output items if we executed any tools + if state.total_calls > 0 { + let server_label = original_body + .tools + .iter() + .find(|t| matches!(t.r#type, ResponseToolType::Mcp)) + .and_then(|t| t.server_label.as_deref()) + .unwrap_or("mcp"); + + // Build mcp_list_tools item + let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label); + + // Insert at beginning of output array + if let Some(output_array) = response_json + .get_mut("output") + .and_then(|v| v.as_array_mut()) + { + output_array.insert(0, list_tools_item); + + // Build mcp_call items using helper function + let mcp_call_items = + build_executed_mcp_call_items(&state.conversation_history, server_label); + + // Insert mcp_call items after mcp_list_tools using mutable position + let mut insert_pos = 1; + for item in mcp_call_items { + output_array.insert(insert_pos, item); + insert_pos += 1; + } + } + } + + return Ok(response_json); + } + } +} + +/// Build an incomplete response when limits are exceeded +pub(super) fn build_incomplete_response( + mut response: Value, + state: ToolLoopState, + reason: &str, + active_mcp: &Arc, + original_body: &ResponsesRequest, +) -> Result { + let obj = response + .as_object_mut() + .ok_or_else(|| "response not an object".to_string())?; + + // Set status to completed (not failed - partial success) + obj.insert("status".to_string(), Value::String("completed".to_string())); + + // Set incomplete_details + obj.insert( + "incomplete_details".to_string(), + json!({ "reason": reason }), + ); + + // Convert any function_call in output to mcp_call format + if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) { + let server_label = original_body + .tools + .iter() + .find(|t| matches!(t.r#type, ResponseToolType::Mcp)) + .and_then(|t| t.server_label.as_deref()) + .unwrap_or("mcp"); + + // Find any function_call items and convert them to mcp_call (incomplete) + let mut mcp_call_items = Vec::new(); + for item in output_array.iter() { + let item_type = item.get("type").and_then(|t| t.as_str()); + if item_type == Some(event_types::ITEM_TYPE_FUNCTION_TOOL_CALL) + || item_type == Some(event_types::ITEM_TYPE_FUNCTION_CALL) + { + let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); + let args = item + .get("arguments") + .and_then(|v| v.as_str()) + .unwrap_or("{}"); + + // Mark as incomplete - not executed + let mcp_call_item = build_mcp_call_item( + tool_name, + args, + "", // No output - wasn't executed + server_label, + false, // Not successful + Some("Not executed - response stopped due to limit"), + ); + mcp_call_items.push(mcp_call_item); + } + } + + // Add mcp_list_tools and executed mcp_call items at the beginning + if state.total_calls > 0 || !mcp_call_items.is_empty() { + let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label); + output_array.insert(0, list_tools_item); + + // Add mcp_call items for executed calls using helper + let executed_items = + build_executed_mcp_call_items(&state.conversation_history, server_label); + + let mut insert_pos = 1; + for item in executed_items { + output_array.insert(insert_pos, item); + insert_pos += 1; + } + + // Add incomplete mcp_call items + for item in mcp_call_items { + output_array.insert(insert_pos, item); + insert_pos += 1; + } + } + } + + // Add warning to metadata + if let Some(metadata_val) = obj.get_mut("metadata") { + if let Some(metadata_obj) = metadata_val.as_object_mut() { + if let Some(mcp_val) = metadata_obj.get_mut("mcp") { + if let Some(mcp_obj) = mcp_val.as_object_mut() { + mcp_obj.insert( + "truncation_warning".to_string(), + Value::String(format!( + "Loop terminated at {} iterations, {} total calls (reason: {})", + state.iteration, state.total_calls, reason + )), + ); + } + } + } + } + + Ok(response) +} + +// ============================================================================ +// Output Item Builders +// ============================================================================ + +/// Generate a unique ID for MCP output items (similar to OpenAI format) +pub(super) fn generate_mcp_id(prefix: &str) -> String { + use rand::RngCore; + let mut rng = rand::rng(); + let mut bytes = [0u8; 30]; + rng.fill_bytes(&mut bytes); + let hex_string: String = bytes.iter().map(|b| format!("{:02x}", b)).collect(); + format!("{}_{}", prefix, hex_string) +} + +/// Build an mcp_list_tools output item +pub(super) fn build_mcp_list_tools_item(mcp: &Arc, server_label: &str) -> Value { + let tools = mcp.list_tools(); + let tools_json: Vec = tools + .iter() + .map(|t| { + json!({ + "name": t.name, + "description": t.description, + "input_schema": t.parameters.clone().unwrap_or_else(|| json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + })), + "annotations": { + "read_only": false + } + }) + }) + .collect(); + + json!({ + "id": generate_mcp_id("mcpl"), + "type": event_types::ITEM_TYPE_MCP_LIST_TOOLS, + "server_label": server_label, + "tools": tools_json + }) +} + +/// Build an mcp_call output item +pub(super) fn build_mcp_call_item( + tool_name: &str, + arguments: &str, + output: &str, + server_label: &str, + success: bool, + error: Option<&str>, +) -> Value { + json!({ + "id": generate_mcp_id("mcp"), + "type": event_types::ITEM_TYPE_MCP_CALL, + "status": if success { "completed" } else { "failed" }, + "approval_request_id": Value::Null, + "arguments": arguments, + "error": error, + "name": tool_name, + "output": output, + "server_label": server_label + }) +} + +/// Helper function to build mcp_call items from executed tool calls in conversation history +pub(super) fn build_executed_mcp_call_items( + conversation_history: &[Value], + server_label: &str, +) -> Vec { + let mut mcp_call_items = Vec::new(); + + for item in conversation_history { + if item.get("type").and_then(|t| t.as_str()) == Some(event_types::ITEM_TYPE_FUNCTION_CALL) { + let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or(""); + let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); + let args = item + .get("arguments") + .and_then(|v| v.as_str()) + .unwrap_or("{}"); + + // Find corresponding output + let output_item = conversation_history.iter().find(|o| { + o.get("type").and_then(|t| t.as_str()) == Some("function_call_output") + && o.get("call_id").and_then(|c| c.as_str()) == Some(call_id) + }); + + let output_str = output_item + .and_then(|o| o.get("output").and_then(|v| v.as_str())) + .unwrap_or("{}"); + + // Check if output contains error by parsing JSON + let is_error = serde_json::from_str::(output_str) + .map(|v| v.get("error").is_some()) + .unwrap_or(false); + + let mcp_call_item = build_mcp_call_item( + tool_name, + args, + output_str, + server_label, + !is_error, + if is_error { + Some("Tool execution failed") + } else { + None + }, + ); + mcp_call_items.push(mcp_call_item); + } + } + + mcp_call_items +} + +// ============================================================================ +// Helper Functions +// ============================================================================ + +/// Extract function call from a response +pub(super) fn extract_function_call(resp: &Value) -> Option<(String, String, String)> { + let output = resp.get("output")?.as_array()?; + for item in output { + let obj = item.as_object()?; + let t = obj.get("type")?.as_str()?; + if t == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL + || t == event_types::ITEM_TYPE_FUNCTION_CALL + { + let call_id = obj + .get("call_id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .or_else(|| { + obj.get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + })?; + let name = obj.get("name")?.as_str()?.to_string(); + let arguments = obj.get("arguments")?.as_str()?.to_string(); + return Some((call_id, name, arguments)); + } + } + None +} diff --git a/sgl-router/src/routers/openai/mod.rs b/sgl-router/src/routers/openai/mod.rs new file mode 100644 index 000000000..9bb2c4d01 --- /dev/null +++ b/sgl-router/src/routers/openai/mod.rs @@ -0,0 +1,18 @@ +//! OpenAI-compatible router implementation +//! +//! This module provides OpenAI-compatible API routing with support for: +//! - Streaming and non-streaming responses +//! - MCP (Model Context Protocol) tool calling +//! - Response storage and conversation management +//! - Multi-turn tool execution loops +//! - SSE (Server-Sent Events) streaming + +mod conversations; +mod mcp; +mod responses; +mod router; +mod streaming; +mod utils; + +// Re-export the main router type for external use +pub use router::OpenAIRouter; diff --git a/sgl-router/src/routers/openai/responses.rs b/sgl-router/src/routers/openai/responses.rs new file mode 100644 index 000000000..866c6d45e --- /dev/null +++ b/sgl-router/src/routers/openai/responses.rs @@ -0,0 +1,368 @@ +//! Response storage, patching, and extraction utilities + +use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse}; +use crate::protocols::spec::{ResponseInput, ResponseToolType, ResponsesRequest}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use tracing::{info, warn}; + +use super::utils::event_types; + +// ============================================================================ +// Response Storage Operations +// ============================================================================ + +/// Store a response internally (checks if storage is enabled) +pub(super) async fn store_response_internal( + response_storage: &SharedResponseStorage, + response_json: &Value, + original_body: &ResponsesRequest, +) -> Result<(), String> { + if !original_body.store { + return Ok(()); + } + + match store_response_impl(response_storage, response_json, original_body).await { + Ok(response_id) => { + info!(response_id = %response_id.0, "Stored response locally"); + Ok(()) + } + Err(e) => Err(e), + } +} + +/// Build a StoredResponse from response JSON and original request +pub(super) fn build_stored_response( + response_json: &Value, + original_body: &ResponsesRequest, +) -> StoredResponse { + let input_text = match &original_body.input { + ResponseInput::Text(text) => text.clone(), + ResponseInput::Items(_) => "complex input".to_string(), + }; + + let output_text = extract_primary_output_text(response_json).unwrap_or_default(); + + let mut stored_response = StoredResponse::new(input_text, output_text, None); + + stored_response.instructions = response_json + .get("instructions") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .or_else(|| original_body.instructions.clone()); + + stored_response.model = response_json + .get("model") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .or_else(|| original_body.model.clone()); + + stored_response.user = response_json + .get("user") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .or_else(|| original_body.user.clone()); + + // Set conversation id from request if provided + if let Some(conv_id) = original_body.conversation.clone() { + stored_response.conversation_id = Some(conv_id); + } + + stored_response.metadata = response_json + .get("metadata") + .and_then(|v| v.as_object()) + .map(|m| { + m.iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>() + }) + .unwrap_or_else(|| original_body.metadata.clone().unwrap_or_default()); + + stored_response.previous_response_id = response_json + .get("previous_response_id") + .and_then(|v| v.as_str()) + .map(ResponseId::from) + .or_else(|| { + original_body + .previous_response_id + .as_ref() + .map(|id| ResponseId::from(id.as_str())) + }); + + if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) { + stored_response.id = ResponseId::from(id_str); + } + + stored_response.raw_response = response_json.clone(); + + stored_response +} + +/// Store response implementation (public for use across modules) +pub(super) async fn store_response_impl( + response_storage: &SharedResponseStorage, + response_json: &Value, + original_body: &ResponsesRequest, +) -> Result { + let stored_response = build_stored_response(response_json, original_body); + + response_storage + .store_response(stored_response) + .await + .map_err(|e| format!("Failed to store response: {}", e)) +} + +// ============================================================================ +// Response JSON Patching +// ============================================================================ + +/// Patch streaming response JSON with metadata from original request +pub(super) fn patch_streaming_response_json( + response_json: &mut Value, + original_body: &ResponsesRequest, + original_previous_response_id: Option<&str>, +) { + if let Some(obj) = response_json.as_object_mut() { + if let Some(prev_id) = original_previous_response_id { + let should_insert = obj + .get("previous_response_id") + .map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false)) + .unwrap_or(true); + if should_insert { + obj.insert( + "previous_response_id".to_string(), + Value::String(prev_id.to_string()), + ); + } + } + + if !obj.contains_key("instructions") + || obj + .get("instructions") + .map(|v| v.is_null()) + .unwrap_or(false) + { + if let Some(instructions) = &original_body.instructions { + obj.insert( + "instructions".to_string(), + Value::String(instructions.clone()), + ); + } + } + + if !obj.contains_key("metadata") + || obj.get("metadata").map(|v| v.is_null()).unwrap_or(false) + { + if let Some(metadata) = &original_body.metadata { + let metadata_map: serde_json::Map = metadata + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + obj.insert("metadata".to_string(), Value::Object(metadata_map)); + } + } + + obj.insert("store".to_string(), Value::Bool(original_body.store)); + + if obj + .get("model") + .and_then(|v| v.as_str()) + .map(|s| s.is_empty()) + .unwrap_or(true) + { + if let Some(model) = &original_body.model { + obj.insert("model".to_string(), Value::String(model.clone())); + } + } + + if obj.get("user").map(|v| v.is_null()).unwrap_or(false) { + if let Some(user) = &original_body.user { + obj.insert("user".to_string(), Value::String(user.clone())); + } + } + + // Attach conversation id for client response if present (final aggregated JSON) + if let Some(conv_id) = original_body.conversation.clone() { + obj.insert("conversation".to_string(), json!({"id": conv_id})); + } + } +} + +/// Rewrite streaming SSE block to include metadata from original request +pub(super) fn rewrite_streaming_block( + block: &str, + original_body: &ResponsesRequest, + original_previous_response_id: Option<&str>, +) -> Option { + let trimmed = block.trim(); + if trimmed.is_empty() { + return None; + } + + let mut data_lines: Vec = Vec::new(); + + for line in trimmed.lines() { + if line.starts_with("data:") { + data_lines.push(line.trim_start_matches("data:").trim_start().to_string()); + } + } + + if data_lines.is_empty() { + return None; + } + + let payload = data_lines.join("\n"); + let mut parsed: Value = match serde_json::from_str(&payload) { + Ok(value) => value, + Err(err) => { + warn!("Failed to parse streaming JSON payload: {}", err); + return None; + } + }; + + let event_type = parsed + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or_default(); + + let should_patch = matches!( + event_type, + event_types::RESPONSE_CREATED + | event_types::RESPONSE_IN_PROGRESS + | event_types::RESPONSE_COMPLETED + ); + + if !should_patch { + return None; + } + + let mut changed = false; + if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) { + let desired_store = Value::Bool(original_body.store); + if response_obj.get("store") != Some(&desired_store) { + response_obj.insert("store".to_string(), desired_store); + changed = true; + } + + if let Some(prev_id) = original_previous_response_id { + let needs_previous = response_obj + .get("previous_response_id") + .map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false)) + .unwrap_or(true); + + if needs_previous { + response_obj.insert( + "previous_response_id".to_string(), + Value::String(prev_id.to_string()), + ); + changed = true; + } + } + + // Attach conversation id into streaming event response content with ordering + if let Some(conv_id) = original_body.conversation.clone() { + response_obj.insert("conversation".to_string(), json!({"id": conv_id})); + changed = true; + } + } + + if !changed { + return None; + } + + let new_payload = match serde_json::to_string(&parsed) { + Ok(json) => json, + Err(err) => { + warn!("Failed to serialize modified streaming payload: {}", err); + return None; + } + }; + + let mut rebuilt_lines = Vec::new(); + let mut data_written = false; + for line in trimmed.lines() { + if line.starts_with("data:") { + if !data_written { + rebuilt_lines.push(format!("data: {}", new_payload)); + data_written = true; + } + } else { + rebuilt_lines.push(line.to_string()); + } + } + + if !data_written { + rebuilt_lines.push(format!("data: {}", new_payload)); + } + + Some(rebuilt_lines.join("\n")) +} + +/// Mask function tools as MCP tools in response for client +pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) { + let mcp_tool = original_body + .tools + .iter() + .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some()); + let Some(t) = mcp_tool else { + return; + }; + + let mut m = serde_json::Map::new(); + m.insert("type".to_string(), Value::String("mcp".to_string())); + if let Some(label) = &t.server_label { + m.insert("server_label".to_string(), Value::String(label.clone())); + } + if let Some(url) = &t.server_url { + m.insert("server_url".to_string(), Value::String(url.clone())); + } + if let Some(desc) = &t.server_description { + m.insert( + "server_description".to_string(), + Value::String(desc.clone()), + ); + } + if let Some(req) = &t.require_approval { + m.insert("require_approval".to_string(), Value::String(req.clone())); + } + if let Some(allowed) = &t.allowed_tools { + m.insert( + "allowed_tools".to_string(), + Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()), + ); + } + + if let Some(obj) = resp.as_object_mut() { + obj.insert("tools".to_string(), Value::Array(vec![Value::Object(m)])); + obj.entry("tool_choice") + .or_insert(Value::String("auto".to_string())); + } +} + +// ============================================================================ +// Output Text Extraction +// ============================================================================ + +/// Extract primary output text from response JSON +pub(super) fn extract_primary_output_text(response_json: &Value) -> Option { + if let Some(items) = response_json.get("output").and_then(|v| v.as_array()) { + for item in items { + if let Some(content) = item.get("content").and_then(|v| v.as_array()) { + for part in content { + if part + .get("type") + .and_then(|v| v.as_str()) + .map(|t| t == "output_text") + .unwrap_or(false) + { + if let Some(text) = part.get("text").and_then(|v| v.as_str()) { + return Some(text.to_string()); + } + } + } + } + } + } + + None +} diff --git a/sgl-router/src/routers/openai/router.rs b/sgl-router/src/routers/openai/router.rs new file mode 100644 index 000000000..478a7301e --- /dev/null +++ b/sgl-router/src/routers/openai/router.rs @@ -0,0 +1,909 @@ +//! OpenAI router - main coordinator that delegates to specialized modules + +use crate::config::CircuitBreakerConfig; +use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}; +use crate::data_connector::{ + conversation_items::ListParams, conversation_items::SortOrder, ConversationId, ResponseId, + SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, +}; +use crate::protocols::spec::{ + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, + ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams, + ResponsesRequest, +}; +use crate::routers::header_utils::apply_request_headers; +use axum::{ + body::Body, + extract::Request, + http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, + response::{IntoResponse, Response}, + Json, +}; +use futures_util::StreamExt; +use serde_json::{json, to_value, Value}; +use std::{ + any::Any, + sync::{atomic::AtomicBool, Arc}, +}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{info, warn}; + +// Import from sibling modules +use super::conversations::{ + create_conversation, delete_conversation, get_conversation, list_conversation_items, + persist_conversation_items, update_conversation, +}; +use super::mcp::{ + execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, + McpLoopConfig, +}; +use super::responses::{mask_tools_as_mcp, patch_streaming_response_json, store_response_internal}; +use super::streaming::handle_streaming_response; + +// ============================================================================ +// OpenAIRouter Struct +// ============================================================================ + +/// Router for OpenAI backend +pub struct OpenAIRouter { + /// HTTP client for upstream OpenAI-compatible API + client: reqwest::Client, + /// Base URL for identification (no trailing slash) + base_url: String, + /// Circuit breaker + circuit_breaker: CircuitBreaker, + /// Health status + healthy: AtomicBool, + /// Response storage for managing conversation history + response_storage: SharedResponseStorage, + /// Conversation storage backend + conversation_storage: SharedConversationStorage, + /// Conversation item storage backend + conversation_item_storage: SharedConversationItemStorage, + /// Optional MCP manager (enabled via config presence) + mcp_manager: Option>, +} + +impl std::fmt::Debug for OpenAIRouter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OpenAIRouter") + .field("base_url", &self.base_url) + .field("healthy", &self.healthy) + .finish() + } +} + +impl OpenAIRouter { + /// Maximum number of conversation items to attach as input when a conversation is provided + const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100; + + /// Create a new OpenAI router + pub async fn new( + base_url: String, + circuit_breaker_config: Option, + response_storage: SharedResponseStorage, + conversation_storage: SharedConversationStorage, + conversation_item_storage: SharedConversationItemStorage, + ) -> Result { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(300)) + .build() + .map_err(|e| format!("Failed to create HTTP client: {}", e))?; + + let base_url = base_url.trim_end_matches('/').to_string(); + + // Convert circuit breaker config + let core_cb_config = circuit_breaker_config + .map(|cb| CoreCircuitBreakerConfig { + failure_threshold: cb.failure_threshold, + success_threshold: cb.success_threshold, + timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs), + window_duration: std::time::Duration::from_secs(cb.window_duration_secs), + }) + .unwrap_or_default(); + + let circuit_breaker = CircuitBreaker::with_config(core_cb_config); + + // Optional MCP manager activation via env var path (config-driven gate) + let mcp_manager = match std::env::var("SGLANG_MCP_CONFIG").ok() { + Some(path) if !path.trim().is_empty() => { + match crate::mcp::McpConfig::from_file(&path).await { + Ok(cfg) => match crate::mcp::McpClientManager::new(cfg).await { + Ok(mgr) => Some(Arc::new(mgr)), + Err(err) => { + warn!("Failed to initialize MCP manager: {}", err); + None + } + }, + Err(err) => { + warn!("Failed to load MCP config from '{}': {}", path, err); + None + } + } + } + _ => None, + }; + + Ok(Self { + client, + base_url, + circuit_breaker, + healthy: AtomicBool::new(true), + response_storage, + conversation_storage, + conversation_item_storage, + mcp_manager, + }) + } + + /// Handle non-streaming response with optional MCP tool loop + async fn handle_non_streaming_response( + &self, + url: String, + headers: Option<&HeaderMap>, + mut payload: Value, + original_body: &ResponsesRequest, + original_previous_response_id: Option, + ) -> Response { + // Check if MCP is active for this request + let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await; + let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref()); + + let mut response_json: Value; + + // If MCP is active, execute tool loop + if let Some(mcp) = active_mcp { + let config = McpLoopConfig::default(); + + // Transform MCP tools to function tools + prepare_mcp_payload_for_streaming(&mut payload, mcp); + + match execute_tool_loop( + &self.client, + &url, + headers, + payload, + original_body, + mcp, + &config, + ) + .await + { + Ok(resp) => response_json = resp, + Err(err) => { + self.circuit_breaker.record_failure(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": {"message": err}})), + ) + .into_response(); + } + } + } else { + // No MCP - simple request + let mut request_builder = self.client.post(&url).json(&payload); + if let Some(h) = headers { + request_builder = apply_request_headers(h, request_builder, true); + } + + let response = match request_builder.send().await { + Ok(r) => r, + Err(e) => { + self.circuit_breaker.record_failure(); + return ( + StatusCode::BAD_GATEWAY, + format!("Failed to forward request to OpenAI: {}", e), + ) + .into_response(); + } + }; + + if !response.status().is_success() { + self.circuit_breaker.record_failure(); + let status = StatusCode::from_u16(response.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + let body = response.text().await.unwrap_or_default(); + return (status, body).into_response(); + } + + response_json = match response.json::().await { + Ok(r) => r, + Err(e) => { + self.circuit_breaker.record_failure(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to parse upstream response: {}", e), + ) + .into_response(); + } + }; + + self.circuit_breaker.record_success(); + } + + // Patch response with metadata + mask_tools_as_mcp(&mut response_json, original_body); + patch_streaming_response_json( + &mut response_json, + original_body, + original_previous_response_id.as_deref(), + ); + + // Persist conversation items if conversation is provided + if original_body.conversation.is_some() { + if let Err(err) = persist_conversation_items( + self.conversation_storage.clone(), + self.conversation_item_storage.clone(), + self.response_storage.clone(), + &response_json, + original_body, + ) + .await + { + warn!("Failed to persist conversation items: {}", err); + } + } else { + // Store response only if no conversation (persist_conversation_items already stores it) + if let Err(err) = + store_response_internal(&self.response_storage, &response_json, original_body).await + { + warn!("Failed to store response: {}", err); + } + } + + (StatusCode::OK, Json(response_json)).into_response() + } +} + +// ============================================================================ +// RouterTrait Implementation +// ============================================================================ + +#[async_trait::async_trait] +impl crate::routers::RouterTrait for OpenAIRouter { + fn as_any(&self) -> &dyn Any { + self + } + + async fn health_generate(&self, _req: Request) -> Response { + // Simple upstream probe: GET {base}/v1/models without auth + let url = format!("{}/v1/models", self.base_url); + match self + .client + .get(&url) + .timeout(std::time::Duration::from_secs(2)) + .send() + .await + { + Ok(resp) => { + let code = resp.status(); + // Treat success and auth-required as healthy (endpoint reachable) + if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 { + (StatusCode::OK, "OK").into_response() + } else { + ( + StatusCode::SERVICE_UNAVAILABLE, + format!("Upstream status: {}", code), + ) + .into_response() + } + } + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + format!("Upstream error: {}", e), + ) + .into_response(), + } + } + + async fn get_server_info(&self, _req: Request) -> Response { + let info = json!({ + "router_type": "openai", + "workers": 1, + "base_url": &self.base_url + }); + (StatusCode::OK, info.to_string()).into_response() + } + + async fn get_models(&self, req: Request) -> Response { + // Proxy to upstream /v1/models; forward Authorization header if provided + let headers = req.headers(); + + let mut upstream = self.client.get(format!("{}/v1/models", self.base_url)); + + if let Some(auth) = headers + .get("authorization") + .or_else(|| headers.get("Authorization")) + { + upstream = upstream.header("Authorization", auth); + } + + match upstream.send().await { + Ok(res) => { + let status = StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + let content_type = res.headers().get(CONTENT_TYPE).cloned(); + match res.bytes().await { + Ok(body) => { + let mut response = Response::new(Body::from(body)); + *response.status_mut() = status; + if let Some(ct) = content_type { + response.headers_mut().insert(CONTENT_TYPE, ct); + } + response + } + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to read upstream response: {}", e), + ) + .into_response(), + } + } + Err(e) => ( + StatusCode::BAD_GATEWAY, + format!("Failed to contact upstream: {}", e), + ) + .into_response(), + } + } + + async fn get_model_info(&self, _req: Request) -> Response { + // Not directly supported without model param; return 501 + ( + StatusCode::NOT_IMPLEMENTED, + "get_model_info not implemented for OpenAI router", + ) + .into_response() + } + + async fn route_generate( + &self, + _headers: Option<&HeaderMap>, + _body: &GenerateRequest, + _model_id: Option<&str>, + ) -> Response { + // Generate endpoint is SGLang-specific, not supported for OpenAI backend + ( + StatusCode::NOT_IMPLEMENTED, + "Generate endpoint not supported for OpenAI backend", + ) + .into_response() + } + + async fn route_chat( + &self, + headers: Option<&HeaderMap>, + body: &ChatCompletionRequest, + _model_id: Option<&str>, + ) -> Response { + if !self.circuit_breaker.can_execute() { + return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response(); + } + + // Serialize request body, removing SGLang-only fields + let mut payload = match to_value(body) { + Ok(v) => v, + Err(e) => { + return ( + StatusCode::BAD_REQUEST, + format!("Failed to serialize request: {}", e), + ) + .into_response(); + } + }; + if let Some(obj) = payload.as_object_mut() { + for key in [ + "top_k", + "min_p", + "min_tokens", + "regex", + "ebnf", + "stop_token_ids", + "no_stop_trim", + "ignore_eos", + "continue_final_message", + "skip_special_tokens", + "lora_path", + "session_params", + "separate_reasoning", + "stream_reasoning", + "chat_template_kwargs", + "return_hidden_states", + "repetition_penalty", + "sampling_seed", + ] { + obj.remove(key); + } + } + + let url = format!("{}/v1/chat/completions", self.base_url); + let mut req = self.client.post(&url).json(&payload); + + // Forward Authorization header if provided + if let Some(h) = headers { + if let Some(auth) = h.get("authorization").or_else(|| h.get("Authorization")) { + req = req.header("Authorization", auth); + } + } + + // Accept SSE when stream=true + if body.stream { + req = req.header("Accept", "text/event-stream"); + } + + let resp = match req.send().await { + Ok(r) => r, + Err(e) => { + self.circuit_breaker.record_failure(); + return ( + StatusCode::SERVICE_UNAVAILABLE, + format!("Failed to contact upstream: {}", e), + ) + .into_response(); + } + }; + + let status = StatusCode::from_u16(resp.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + if !body.stream { + // Capture Content-Type before consuming response body + let content_type = resp.headers().get(CONTENT_TYPE).cloned(); + match resp.bytes().await { + Ok(body) => { + self.circuit_breaker.record_success(); + let mut response = Response::new(Body::from(body)); + *response.status_mut() = status; + if let Some(ct) = content_type { + response.headers_mut().insert(CONTENT_TYPE, ct); + } + response + } + Err(e) => { + self.circuit_breaker.record_failure(); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to read response: {}", e), + ) + .into_response() + } + } + } else { + // Stream SSE bytes to client + let stream = resp.bytes_stream(); + let (tx, rx) = mpsc::unbounded_channel(); + tokio::spawn(async move { + let mut s = stream; + while let Some(chunk) = s.next().await { + match chunk { + Ok(bytes) => { + if tx.send(Ok(bytes)).is_err() { + break; + } + } + Err(e) => { + let _ = tx.send(Err(format!("Stream error: {}", e))); + break; + } + } + } + }); + let mut response = Response::new(Body::from_stream(UnboundedReceiverStream::new(rx))); + *response.status_mut() = status; + response + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + response + } + } + + async fn route_completion( + &self, + _headers: Option<&HeaderMap>, + _body: &CompletionRequest, + _model_id: Option<&str>, + ) -> Response { + // Completion endpoint not implemented for OpenAI backend + ( + StatusCode::NOT_IMPLEMENTED, + "Completion endpoint not implemented for OpenAI backend", + ) + .into_response() + } + + async fn route_responses( + &self, + headers: Option<&HeaderMap>, + body: &ResponsesRequest, + model_id: Option<&str>, + ) -> Response { + let url = format!("{}/v1/responses", self.base_url); + + info!( + requested_store = body.store, + is_streaming = body.stream, + "openai_responses_request" + ); + + // Validate mutually exclusive params: previous_response_id and conversation + // TODO: this validation logic should move the right place, also we need a proper error message module + if body.previous_response_id.is_some() && body.conversation.is_some() { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": { + "message": "Mutually exclusive parameters. Ensure you are only providing one of: 'previous_response_id' or 'conversation'.", + "type": "invalid_request_error", + "param": Value::Null, + "code": "mutually_exclusive_parameters" + } + })), + ) + .into_response(); + } + + // Clone the body and override model if needed + let mut request_body = body.clone(); + if let Some(model) = model_id { + request_body.model = Some(model.to_string()); + } + // Do not forward conversation field upstream; retain for local persistence only + request_body.conversation = None; + + // Store the original previous_response_id for the response + let original_previous_response_id = request_body.previous_response_id.clone(); + + // Handle previous_response_id by loading prior context + let mut conversation_items: Option> = None; + if let Some(prev_id_str) = request_body.previous_response_id.clone() { + let prev_id = ResponseId::from(prev_id_str.as_str()); + match self + .response_storage + .get_response_chain(&prev_id, None) + .await + { + Ok(chain) => { + let mut items = Vec::new(); + for stored in chain.responses.iter() { + // Convert input to conversation item + items.push(ResponseInputOutputItem::Message { + id: format!("msg_u_{}", stored.id.0.trim_start_matches("resp_")), + role: "user".to_string(), + content: vec![ResponseContentPart::InputText { + text: stored.input.clone(), + }], + status: Some("completed".to_string()), + }); + + // Convert output to conversation items directly from stored response + if let Some(output_arr) = + stored.raw_response.get("output").and_then(|v| v.as_array()) + { + for item in output_arr { + if let Ok(output_item) = + serde_json::from_value::(item.clone()) + { + items.push(output_item); + } + } + } + } + conversation_items = Some(items); + request_body.previous_response_id = None; + } + Err(e) => { + warn!( + "Failed to load previous response chain for {}: {}", + prev_id_str, e + ); + } + } + } + + // Handle conversation by loading history + if let Some(conv_id_str) = body.conversation.clone() { + let conv_id = ConversationId::from(conv_id_str.as_str()); + + // Verify conversation exists + if let Ok(None) = self.conversation_storage.get_conversation(&conv_id).await { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "Conversation not found"})), + ) + .into_response(); + } + + // Load conversation history (ascending order for chronological context) + let params = ListParams { + limit: Self::MAX_CONVERSATION_HISTORY_ITEMS, + order: SortOrder::Asc, + after: None, + }; + + match self + .conversation_item_storage + .list_items(&conv_id, params) + .await + { + Ok(stored_items) => { + let mut items: Vec = Vec::new(); + for item in stored_items.into_iter() { + // Only use message items for conversation context + // Skip non-message items (reasoning, function calls, etc.) + if item.item_type == "message" { + if let Ok(content_parts) = + serde_json::from_value::>( + item.content.clone(), + ) + { + items.push(ResponseInputOutputItem::Message { + id: item.id.0.clone(), + role: item.role.clone().unwrap_or_else(|| "user".to_string()), + content: content_parts, + status: item.status.clone(), + }); + } + } + } + + // Append current request + match &request_body.input { + ResponseInput::Text(text) => { + items.push(ResponseInputOutputItem::Message { + id: format!("msg_u_{}", conv_id.0), + role: "user".to_string(), + content: vec![ResponseContentPart::InputText { + text: text.clone(), + }], + status: Some("completed".to_string()), + }); + } + ResponseInput::Items(current_items) => { + items.extend_from_slice(current_items); + } + } + + request_body.input = ResponseInput::Items(items); + } + Err(e) => { + warn!("Failed to load conversation history: {}", e); + } + } + } + + // If we have conversation_items from previous_response_id, use them + if let Some(mut items) = conversation_items { + // Append current request + match &request_body.input { + ResponseInput::Text(text) => { + items.push(ResponseInputOutputItem::Message { + id: format!( + "msg_u_{}", + original_previous_response_id + .as_ref() + .unwrap_or(&"new".to_string()) + ), + role: "user".to_string(), + content: vec![ResponseContentPart::InputText { text: text.clone() }], + status: Some("completed".to_string()), + }); + } + ResponseInput::Items(current_items) => { + items.extend_from_slice(current_items); + } + } + + request_body.input = ResponseInput::Items(items); + } + + // Always set store=false for upstream (we store internally) + request_body.store = false; + + // Convert to JSON and strip SGLang-specific fields + let mut payload = match to_value(&request_body) { + Ok(v) => v, + Err(e) => { + return ( + StatusCode::BAD_REQUEST, + format!("Failed to serialize request: {}", e), + ) + .into_response(); + } + }; + + // Remove SGLang-specific fields + if let Some(obj) = payload.as_object_mut() { + for key in [ + "request_id", + "priority", + "top_k", + "frequency_penalty", + "presence_penalty", + "min_p", + "min_tokens", + "regex", + "ebnf", + "stop_token_ids", + "no_stop_trim", + "ignore_eos", + "continue_final_message", + "skip_special_tokens", + "lora_path", + "session_params", + "separate_reasoning", + "stream_reasoning", + "chat_template_kwargs", + "return_hidden_states", + "repetition_penalty", + "sampling_seed", + ] { + obj.remove(key); + } + } + + // Delegate to streaming or non-streaming handler + if body.stream { + handle_streaming_response( + &self.client, + &self.circuit_breaker, + self.mcp_manager.as_ref(), + self.response_storage.clone(), + self.conversation_storage.clone(), + self.conversation_item_storage.clone(), + url, + headers, + payload, + body, + original_previous_response_id, + ) + .await + } else { + self.handle_non_streaming_response( + url, + headers, + payload, + body, + original_previous_response_id, + ) + .await + } + } + + async fn get_response( + &self, + _headers: Option<&HeaderMap>, + response_id: &str, + _params: &ResponsesGetParams, + ) -> Response { + let id = ResponseId::from(response_id); + match self.response_storage.get_response(&id).await { + Ok(Some(stored)) => { + let mut response_json = stored.raw_response; + if let Some(obj) = response_json.as_object_mut() { + obj.insert("id".to_string(), json!(id.0)); + } + (StatusCode::OK, Json(response_json)).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(json!({"error": "Response not found"})), + ) + .into_response(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": format!("Failed to get response: {}", e)})), + ) + .into_response(), + } + } + + async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { + // Forward cancellation to upstream + let url = format!("{}/v1/responses/{}/cancel", self.base_url, response_id); + let mut req = self.client.post(&url); + + if let Some(h) = headers { + req = apply_request_headers(h, req, false); + } + + match req.send().await { + Ok(resp) => { + let status = StatusCode::from_u16(resp.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + match resp.text().await { + Ok(body) => (status, body).into_response(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to read response: {}", e), + ) + .into_response(), + } + } + Err(e) => ( + StatusCode::BAD_GATEWAY, + format!("Failed to contact upstream: {}", e), + ) + .into_response(), + } + } + + async fn route_embeddings( + &self, + _headers: Option<&HeaderMap>, + _body: &EmbeddingRequest, + _model_id: Option<&str>, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED, "Embeddings not supported").into_response() + } + + async fn route_rerank( + &self, + _headers: Option<&HeaderMap>, + _body: &RerankRequest, + _model_id: Option<&str>, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED, "Rerank not supported").into_response() + } + + async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response { + create_conversation(&self.conversation_storage, body.clone()).await + } + + async fn get_conversation( + &self, + _headers: Option<&HeaderMap>, + conversation_id: &str, + ) -> Response { + get_conversation(&self.conversation_storage, conversation_id).await + } + + async fn update_conversation( + &self, + _headers: Option<&HeaderMap>, + conversation_id: &str, + body: &Value, + ) -> Response { + update_conversation(&self.conversation_storage, conversation_id, body.clone()).await + } + + async fn delete_conversation( + &self, + _headers: Option<&HeaderMap>, + conversation_id: &str, + ) -> Response { + delete_conversation(&self.conversation_storage, conversation_id).await + } + + fn router_type(&self) -> &'static str { + "openai" + } + + async fn list_conversation_items( + &self, + _headers: Option<&HeaderMap>, + conversation_id: &str, + limit: Option, + order: Option, + after: Option, + ) -> Response { + let mut query_params = std::collections::HashMap::new(); + query_params.insert("limit".to_string(), limit.unwrap_or(100).to_string()); + if let Some(after_val) = after { + if !after_val.is_empty() { + query_params.insert("after".to_string(), after_val); + } + } + if let Some(order_val) = order { + query_params.insert("order".to_string(), order_val); + } + + list_conversation_items( + &self.conversation_storage, + &self.conversation_item_storage, + conversation_id, + query_params, + ) + .await + } +} diff --git a/sgl-router/src/routers/openai/streaming.rs b/sgl-router/src/routers/openai/streaming.rs new file mode 100644 index 000000000..a60c57949 --- /dev/null +++ b/sgl-router/src/routers/openai/streaming.rs @@ -0,0 +1,1550 @@ +//! Streaming response handling for OpenAI-compatible responses +//! +//! This module handles all streaming-related functionality including: +//! - SSE (Server-Sent Events) parsing and forwarding +//! - Streaming response accumulation for persistence +//! - Tool call detection and interception during streaming +//! - MCP tool execution loops within streaming responses +//! - Event transformation and output index remapping + +use crate::data_connector::{ + SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, +}; +use crate::protocols::spec::{ResponseToolType, ResponsesRequest}; +use crate::routers::header_utils::{apply_request_headers, preserve_response_headers}; +use axum::{ + body::Body, + http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, + response::{IntoResponse, Response}, +}; +use bytes::Bytes; +use futures_util::StreamExt; +use serde_json::{json, Value}; +use std::{borrow::Cow, io, sync::Arc}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::warn; + +// Import from sibling modules +use super::conversations::persist_conversation_items; +use super::mcp::{ + build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming, + mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, send_mcp_list_tools_events, + McpLoopConfig, ToolLoopState, +}; +use super::responses::{ + mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block, store_response_impl, +}; +use super::utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction}; + +// ============================================================================ +// Streaming Response Accumulator +// ============================================================================ + +/// Helper that parses SSE frames from the OpenAI responses stream and +/// accumulates enough information to persist the final response locally. +pub(super) struct StreamingResponseAccumulator { + /// The initial `response.created` payload (if emitted). + initial_response: Option, + /// The final `response.completed` payload (if emitted). + completed_response: Option, + /// Collected output items keyed by the upstream output index, used when + /// a final response payload is absent and we need to synthesize one. + output_items: Vec<(usize, Value)>, + /// Captured error payload (if the upstream stream fails midway). + encountered_error: Option, +} + +impl StreamingResponseAccumulator { + pub fn new() -> Self { + Self { + initial_response: None, + completed_response: None, + output_items: Vec::new(), + encountered_error: None, + } + } + + /// Feed the accumulator with the next SSE chunk. + pub fn ingest_block(&mut self, block: &str) { + if block.trim().is_empty() { + return; + } + self.process_block(block); + } + + /// Consume the accumulator and produce the best-effort final response value. + pub fn into_final_response(mut self) -> Option { + if self.completed_response.is_some() { + return self.completed_response; + } + + self.build_fallback_response() + } + + pub fn encountered_error(&self) -> Option<&Value> { + self.encountered_error.as_ref() + } + + pub fn original_response_id(&self) -> Option<&str> { + self.initial_response + .as_ref() + .and_then(|response| response.get("id")) + .and_then(|id| id.as_str()) + } + + pub fn snapshot_final_response(&self) -> Option { + if let Some(resp) = &self.completed_response { + return Some(resp.clone()); + } + self.build_fallback_response_snapshot() + } + + fn build_fallback_response_snapshot(&self) -> Option { + let mut response = self.initial_response.clone()?; + + if let Some(obj) = response.as_object_mut() { + obj.insert("status".to_string(), Value::String("completed".to_string())); + + let mut output_items = self.output_items.clone(); + output_items.sort_by_key(|(index, _)| *index); + let outputs: Vec = output_items.into_iter().map(|(_, item)| item).collect(); + obj.insert("output".to_string(), Value::Array(outputs)); + } + + Some(response) + } + + fn process_block(&mut self, block: &str) { + let trimmed = block.trim(); + if trimmed.is_empty() { + return; + } + + let mut event_name: Option = None; + let mut data_lines: Vec = Vec::new(); + + for line in trimmed.lines() { + if let Some(rest) = line.strip_prefix("event:") { + event_name = Some(rest.trim().to_string()); + } else if let Some(rest) = line.strip_prefix("data:") { + data_lines.push(rest.trim_start().to_string()); + } + } + + let data_payload = data_lines.join("\n"); + if data_payload.is_empty() { + return; + } + + self.handle_event(event_name.as_deref(), &data_payload); + } + + fn handle_event(&mut self, event_name: Option<&str>, data_payload: &str) { + let parsed: Value = match serde_json::from_str(data_payload) { + Ok(value) => value, + Err(err) => { + warn!("Failed to parse streaming event JSON: {}", err); + return; + } + }; + + let event_type = event_name + .map(|s| s.to_string()) + .or_else(|| { + parsed + .get("type") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_default(); + + match event_type.as_str() { + event_types::RESPONSE_CREATED => { + if self.initial_response.is_none() { + if let Some(response) = parsed.get("response") { + self.initial_response = Some(response.clone()); + } + } + } + event_types::RESPONSE_COMPLETED => { + if let Some(response) = parsed.get("response") { + self.completed_response = Some(response.clone()); + } + } + event_types::OUTPUT_ITEM_DONE => { + if let (Some(index), Some(item)) = ( + parsed + .get("output_index") + .and_then(|v| v.as_u64()) + .map(|v| v as usize), + parsed.get("item"), + ) { + self.output_items.push((index, item.clone())); + } + } + "response.error" => { + self.encountered_error = Some(parsed); + } + _ => {} + } + } + + fn build_fallback_response(&mut self) -> Option { + let mut response = self.initial_response.clone()?; + + if let Some(obj) = response.as_object_mut() { + obj.insert("status".to_string(), Value::String("completed".to_string())); + + self.output_items.sort_by_key(|(index, _)| *index); + let outputs: Vec = self + .output_items + .iter() + .map(|(_, item)| item.clone()) + .collect(); + obj.insert("output".to_string(), Value::Array(outputs)); + } + + Some(response) + } +} + +// ============================================================================ +// Streaming Tool Handler +// ============================================================================ + +/// Handles streaming responses with MCP tool call interception +pub(super) struct StreamingToolHandler { + /// Accumulator for response persistence + pub accumulator: StreamingResponseAccumulator, + /// Function calls being built from deltas + pub pending_calls: Vec, + /// Track if we're currently in a function call + in_function_call: bool, + /// Manage output_index remapping so they increment per item + output_index_mapper: OutputIndexMapper, + /// Original response id captured from the first response.created event + pub original_response_id: Option, +} + +impl StreamingToolHandler { + pub fn with_starting_index(start: usize) -> Self { + Self { + accumulator: StreamingResponseAccumulator::new(), + pending_calls: Vec::new(), + in_function_call: false, + output_index_mapper: OutputIndexMapper::with_start(start), + original_response_id: None, + } + } + + pub fn ensure_output_index(&mut self, upstream_index: usize) -> usize { + self.output_index_mapper.ensure_mapping(upstream_index) + } + + pub fn mapped_output_index(&self, upstream_index: usize) -> Option { + self.output_index_mapper.lookup(upstream_index) + } + + pub fn allocate_synthetic_output_index(&mut self) -> usize { + self.output_index_mapper.allocate_synthetic() + } + + pub fn next_output_index(&self) -> usize { + self.output_index_mapper.next_index() + } + + pub fn original_response_id(&self) -> Option<&str> { + self.original_response_id + .as_deref() + .or_else(|| self.accumulator.original_response_id()) + } + + pub fn snapshot_final_response(&self) -> Option { + self.accumulator.snapshot_final_response() + } + + /// Process an SSE event and determine what action to take + pub fn process_event(&mut self, event_name: Option<&str>, data: &str) -> StreamAction { + // Always feed to accumulator for storage + self.accumulator.ingest_block(&format!( + "{}data: {}", + event_name + .map(|n| format!("event: {}\n", n)) + .unwrap_or_default(), + data + )); + + let parsed: Value = match serde_json::from_str(data) { + Ok(v) => v, + Err(_) => return StreamAction::Forward, + }; + + let event_type = event_name + .map(|s| s.to_string()) + .or_else(|| { + parsed + .get("type") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_default(); + + match event_type.as_str() { + event_types::RESPONSE_CREATED => { + if self.original_response_id.is_none() { + if let Some(response_obj) = parsed.get("response").and_then(|v| v.as_object()) { + if let Some(id) = response_obj.get("id").and_then(|v| v.as_str()) { + self.original_response_id = Some(id.to_string()); + } + } + } + StreamAction::Forward + } + event_types::RESPONSE_COMPLETED => StreamAction::Forward, + event_types::OUTPUT_ITEM_ADDED => { + if let Some(idx) = parsed.get("output_index").and_then(|v| v.as_u64()) { + self.ensure_output_index(idx as usize); + } + + // Check if this is a function_call item being added + if let Some(item) = parsed.get("item") { + if let Some(item_type) = item.get("type").and_then(|v| v.as_str()) { + if item_type == event_types::ITEM_TYPE_FUNCTION_CALL + || item_type == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL + { + match parsed.get("output_index").and_then(|v| v.as_u64()) { + Some(idx) => { + let output_index = idx as usize; + let assigned_index = self.ensure_output_index(output_index); + let call_id = + item.get("call_id").and_then(|v| v.as_str()).unwrap_or(""); + let name = + item.get("name").and_then(|v| v.as_str()).unwrap_or(""); + + // Create or update the function call + let call = self.get_or_create_call(output_index, item); + call.call_id = call_id.to_string(); + call.name = name.to_string(); + call.assigned_output_index = Some(assigned_index); + + self.in_function_call = true; + } + None => { + warn!( + "Missing output_index in function_call added event, \ + forwarding without processing for tool execution" + ); + } + } + } + } + } + StreamAction::Forward + } + event_types::FUNCTION_CALL_ARGUMENTS_DELTA => { + // Accumulate arguments for the function call + if let Some(output_index) = parsed + .get("output_index") + .and_then(|v| v.as_u64()) + .map(|v| v as usize) + { + let assigned_index = self.ensure_output_index(output_index); + if let Some(delta) = parsed.get("delta").and_then(|v| v.as_str()) { + if let Some(call) = self + .pending_calls + .iter_mut() + .find(|c| c.output_index == output_index) + { + call.arguments_buffer.push_str(delta); + if let Some(obfuscation) = + parsed.get("obfuscation").and_then(|v| v.as_str()) + { + call.last_obfuscation = Some(obfuscation.to_string()); + } + if call.assigned_output_index.is_none() { + call.assigned_output_index = Some(assigned_index); + } + } + } + } + StreamAction::Forward + } + event_types::FUNCTION_CALL_ARGUMENTS_DONE => { + // Function call arguments complete - check if ready to execute + if let Some(output_index) = parsed + .get("output_index") + .and_then(|v| v.as_u64()) + .map(|v| v as usize) + { + let assigned_index = self.ensure_output_index(output_index); + if let Some(call) = self + .pending_calls + .iter_mut() + .find(|c| c.output_index == output_index) + { + if call.assigned_output_index.is_none() { + call.assigned_output_index = Some(assigned_index); + } + } + } + + if self.has_complete_calls() { + StreamAction::ExecuteTools + } else { + StreamAction::Forward + } + } + event_types::OUTPUT_ITEM_DELTA => self.process_output_delta(&parsed), + event_types::OUTPUT_ITEM_DONE => { + // Check if we have complete function calls ready to execute + if let Some(output_index) = parsed + .get("output_index") + .and_then(|v| v.as_u64()) + .map(|v| v as usize) + { + self.ensure_output_index(output_index); + } + + if self.has_complete_calls() { + StreamAction::ExecuteTools + } else { + StreamAction::Forward + } + } + _ => StreamAction::Forward, + } + } + + /// Process output delta events to detect and accumulate function calls + fn process_output_delta(&mut self, event: &Value) -> StreamAction { + let output_index = event + .get("output_index") + .and_then(|v| v.as_u64()) + .map(|v| v as usize) + .unwrap_or(0); + + let assigned_index = self.ensure_output_index(output_index); + + let delta = match event.get("delta") { + Some(d) => d, + None => return StreamAction::Forward, + }; + + // Check if this is a function call delta + let item_type = delta.get("type").and_then(|v| v.as_str()); + + if item_type == Some(event_types::ITEM_TYPE_FUNCTION_TOOL_CALL) + || item_type == Some(event_types::ITEM_TYPE_FUNCTION_CALL) + { + self.in_function_call = true; + + // Get or create function call for this output index + let call = self.get_or_create_call(output_index, delta); + call.assigned_output_index = Some(assigned_index); + + // Accumulate call_id if present + if let Some(call_id) = delta.get("call_id").and_then(|v| v.as_str()) { + call.call_id = call_id.to_string(); + } + + // Accumulate name if present + if let Some(name) = delta.get("name").and_then(|v| v.as_str()) { + call.name.push_str(name); + } + + // Accumulate arguments if present + if let Some(args) = delta.get("arguments").and_then(|v| v.as_str()) { + call.arguments_buffer.push_str(args); + } + + if let Some(obfuscation) = delta.get("obfuscation").and_then(|v| v.as_str()) { + call.last_obfuscation = Some(obfuscation.to_string()); + } + + // Buffer this event, don't forward to client + return StreamAction::Buffer; + } + + // Forward non-function-call events + StreamAction::Forward + } + + fn get_or_create_call( + &mut self, + output_index: usize, + delta: &Value, + ) -> &mut FunctionCallInProgress { + // Find existing call for this output index + if let Some(pos) = self + .pending_calls + .iter() + .position(|c| c.output_index == output_index) + { + return &mut self.pending_calls[pos]; + } + + // Create new call + let call_id = delta + .get("call_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let mut call = FunctionCallInProgress::new(call_id, output_index); + if let Some(obfuscation) = delta.get("obfuscation").and_then(|v| v.as_str()) { + call.last_obfuscation = Some(obfuscation.to_string()); + } + + self.pending_calls.push(call); + self.pending_calls + .last_mut() + .expect("Just pushed to pending_calls, must have at least one element") + } + + fn has_complete_calls(&self) -> bool { + !self.pending_calls.is_empty() && self.pending_calls.iter().all(|c| c.is_complete()) + } + + pub fn take_pending_calls(&mut self) -> Vec { + std::mem::take(&mut self.pending_calls) + } +} + +// ============================================================================ +// SSE Parsing +// ============================================================================ + +/// Parse an SSE block into event name and data +/// +/// Returns borrowed strings when possible to avoid allocations in hot paths. +/// Only allocates when multiple data lines need to be joined. +pub(super) fn parse_sse_block(block: &str) -> (Option<&str>, Cow<'_, str>) { + let mut event_name: Option<&str> = None; + let mut data_lines: Vec<&str> = Vec::new(); + + for line in block.lines() { + if let Some(rest) = line.strip_prefix("event:") { + event_name = Some(rest.trim()); + } else if let Some(rest) = line.strip_prefix("data:") { + data_lines.push(rest.trim_start()); + } + } + + let data = if data_lines.len() == 1 { + Cow::Borrowed(data_lines[0]) + } else { + Cow::Owned(data_lines.join("\n")) + }; + + (event_name, data) +} + +// ============================================================================ +// Event Transformation and Forwarding +// ============================================================================ + +/// Apply all transformations to event data in-place (rewrite + transform) +/// Optimized to parse JSON only once instead of multiple times +/// Returns true if any changes were made +pub(super) fn apply_event_transformations_inplace( + parsed_data: &mut Value, + server_label: &str, + original_request: &ResponsesRequest, + previous_response_id: Option<&str>, +) -> bool { + let mut changed = false; + + // 1. Apply rewrite_streaming_block logic (store, previous_response_id, tools masking) + let event_type = parsed_data + .get("type") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .unwrap_or_default(); + + let should_patch = matches!( + event_type.as_str(), + event_types::RESPONSE_CREATED + | event_types::RESPONSE_IN_PROGRESS + | event_types::RESPONSE_COMPLETED + ); + + if should_patch { + if let Some(response_obj) = parsed_data + .get_mut("response") + .and_then(|v| v.as_object_mut()) + { + let desired_store = Value::Bool(original_request.store); + if response_obj.get("store") != Some(&desired_store) { + response_obj.insert("store".to_string(), desired_store); + changed = true; + } + + if let Some(prev_id) = previous_response_id { + let needs_previous = response_obj + .get("previous_response_id") + .map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false)) + .unwrap_or(true); + + if needs_previous { + response_obj.insert( + "previous_response_id".to_string(), + Value::String(prev_id.to_string()), + ); + changed = true; + } + } + + // Mask tools from function to MCP format (optimized without cloning) + if response_obj.get("tools").is_some() { + let requested_mcp = original_request + .tools + .iter() + .any(|t| matches!(t.r#type, ResponseToolType::Mcp)); + + if requested_mcp { + if let Some(mcp_tools) = build_mcp_tools_value(original_request) { + response_obj.insert("tools".to_string(), mcp_tools); + response_obj + .entry("tool_choice".to_string()) + .or_insert(Value::String("auto".to_string())); + changed = true; + } + } + } + } + } + + // 2. Apply transform_streaming_event logic (function_call → mcp_call) + match event_type.as_str() { + event_types::OUTPUT_ITEM_ADDED | event_types::OUTPUT_ITEM_DONE => { + if let Some(item) = parsed_data.get_mut("item") { + if let Some(item_type) = item.get("type").and_then(|v| v.as_str()) { + if item_type == event_types::ITEM_TYPE_FUNCTION_CALL + || item_type == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL + { + item["type"] = json!(event_types::ITEM_TYPE_MCP_CALL); + item["server_label"] = json!(server_label); + + // Transform ID from fc_* to mcp_* + if let Some(id) = item.get("id").and_then(|v| v.as_str()) { + if let Some(stripped) = id.strip_prefix("fc_") { + let new_id = format!("mcp_{}", stripped); + item["id"] = json!(new_id); + } + } + + changed = true; + } + } + } + } + event_types::FUNCTION_CALL_ARGUMENTS_DONE => { + parsed_data["type"] = json!(event_types::MCP_CALL_ARGUMENTS_DONE); + + // Transform item_id from fc_* to mcp_* + if let Some(item_id) = parsed_data.get("item_id").and_then(|v| v.as_str()) { + if let Some(stripped) = item_id.strip_prefix("fc_") { + let new_id = format!("mcp_{}", stripped); + parsed_data["item_id"] = json!(new_id); + } + } + + changed = true; + } + _ => {} + } + + changed +} + +/// Helper to build MCP tools value +fn build_mcp_tools_value(original_body: &ResponsesRequest) -> Option { + let mcp_tool = original_body + .tools + .iter() + .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?; + + let tools_array = vec![json!({ + "type": "mcp", + "server_label": mcp_tool.server_label, + "server_url": mcp_tool.server_url + })]; + + Some(Value::Array(tools_array)) +} + +/// Forward and transform a streaming event to the client +/// Returns false if client disconnected +#[allow(clippy::too_many_arguments)] +pub(super) fn forward_streaming_event( + raw_block: &str, + event_name: Option<&str>, + data: &str, + handler: &mut StreamingToolHandler, + tx: &mpsc::UnboundedSender>, + server_label: &str, + original_request: &ResponsesRequest, + previous_response_id: Option<&str>, + sequence_number: &mut u64, +) -> bool { + // Skip individual function_call_arguments.delta events - we'll send them as one + if event_name == Some(event_types::FUNCTION_CALL_ARGUMENTS_DELTA) { + return true; + } + + // Parse JSON data once (optimized!) + let mut parsed_data: Value = match serde_json::from_str(data) { + Ok(v) => v, + Err(_) => { + // If parsing fails, forward raw block as-is + let chunk_to_send = format!("{}\n\n", raw_block); + return tx.send(Ok(Bytes::from(chunk_to_send))).is_ok(); + } + }; + + let event_type = event_name + .or_else(|| parsed_data.get("type").and_then(|v| v.as_str())) + .unwrap_or(""); + + if event_type == event_types::RESPONSE_COMPLETED { + return true; + } + + // Check if this is function_call_arguments.done - need to send buffered args first + let mut mapped_output_index: Option = None; + + if event_name == Some(event_types::FUNCTION_CALL_ARGUMENTS_DONE) { + if let Some(output_index) = parsed_data + .get("output_index") + .and_then(|v| v.as_u64()) + .map(|v| v as usize) + { + let assigned_index = handler + .mapped_output_index(output_index) + .unwrap_or(output_index); + mapped_output_index = Some(assigned_index); + + if let Some(call) = handler + .pending_calls + .iter() + .find(|c| c.output_index == output_index) + { + let arguments_value = if call.arguments_buffer.is_empty() { + "{}".to_string() + } else { + call.arguments_buffer.clone() + }; + + // Make sure the done event carries full arguments + parsed_data["arguments"] = Value::String(arguments_value.clone()); + + // Get item_id and transform it + let item_id = parsed_data + .get("item_id") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let mcp_item_id = if let Some(stripped) = item_id.strip_prefix("fc_") { + format!("mcp_{}", stripped) + } else { + item_id.to_string() + }; + + // Emit a synthetic MCP arguments delta event before the done event + let mut delta_event = json!({ + "type": event_types::MCP_CALL_ARGUMENTS_DELTA, + "sequence_number": *sequence_number, + "output_index": assigned_index, + "item_id": mcp_item_id, + "delta": arguments_value, + }); + + if let Some(obfuscation) = call.last_obfuscation.as_ref() { + if let Some(obj) = delta_event.as_object_mut() { + obj.insert( + "obfuscation".to_string(), + Value::String(obfuscation.clone()), + ); + } + } else if let Some(obfuscation) = parsed_data.get("obfuscation").cloned() { + if let Some(obj) = delta_event.as_object_mut() { + obj.insert("obfuscation".to_string(), obfuscation); + } + } + + let delta_block = format!( + "event: {}\ndata: {}\n\n", + event_types::MCP_CALL_ARGUMENTS_DELTA, + delta_event + ); + if tx.send(Ok(Bytes::from(delta_block))).is_err() { + return false; + } + + *sequence_number += 1; + } + } + } + + // Remap output_index (if present) so downstream sees sequential indices + if mapped_output_index.is_none() { + if let Some(output_index) = parsed_data + .get("output_index") + .and_then(|v| v.as_u64()) + .map(|v| v as usize) + { + mapped_output_index = handler.mapped_output_index(output_index); + } + } + + if let Some(mapped) = mapped_output_index { + parsed_data["output_index"] = json!(mapped); + } + + // Apply all transformations in-place (single parse/serialize!) + apply_event_transformations_inplace( + &mut parsed_data, + server_label, + original_request, + previous_response_id, + ); + + if let Some(response_obj) = parsed_data + .get_mut("response") + .and_then(|v| v.as_object_mut()) + { + if let Some(original_id) = handler.original_response_id() { + response_obj.insert("id".to_string(), Value::String(original_id.to_string())); + } + } + + // Update sequence number if present in the event + if parsed_data.get("sequence_number").is_some() { + parsed_data["sequence_number"] = json!(*sequence_number); + *sequence_number += 1; + } + + // Serialize once + let final_data = match serde_json::to_string(&parsed_data) { + Ok(s) => s, + Err(_) => { + // Serialization failed, forward original + let chunk_to_send = format!("{}\n\n", raw_block); + return tx.send(Ok(Bytes::from(chunk_to_send))).is_ok(); + } + }; + + // Rebuild SSE block with potentially transformed event name + let mut final_block = String::new(); + if let Some(evt) = event_name { + // Update event name for function_call_arguments events + if evt == event_types::FUNCTION_CALL_ARGUMENTS_DELTA { + final_block.push_str(&format!( + "event: {}\n", + event_types::MCP_CALL_ARGUMENTS_DELTA + )); + } else if evt == event_types::FUNCTION_CALL_ARGUMENTS_DONE { + final_block.push_str(&format!( + "event: {}\n", + event_types::MCP_CALL_ARGUMENTS_DONE + )); + } else { + final_block.push_str(&format!("event: {}\n", evt)); + } + } + final_block.push_str(&format!("data: {}", final_data)); + + let chunk_to_send = format!("{}\n\n", final_block); + if tx.send(Ok(Bytes::from(chunk_to_send))).is_err() { + return false; + } + + // After sending output_item.added for mcp_call, inject mcp_call.in_progress event + if event_name == Some(event_types::OUTPUT_ITEM_ADDED) { + if let Some(item) = parsed_data.get("item") { + if item.get("type").and_then(|v| v.as_str()) == Some(event_types::ITEM_TYPE_MCP_CALL) { + // Already transformed to mcp_call + if let (Some(item_id), Some(output_index)) = ( + item.get("id").and_then(|v| v.as_str()), + parsed_data.get("output_index").and_then(|v| v.as_u64()), + ) { + let in_progress_event = json!({ + "type": event_types::MCP_CALL_IN_PROGRESS, + "sequence_number": *sequence_number, + "output_index": output_index, + "item_id": item_id + }); + *sequence_number += 1; + let in_progress_block = format!( + "event: {}\ndata: {}\n\n", + event_types::MCP_CALL_IN_PROGRESS, + in_progress_event + ); + if tx.send(Ok(Bytes::from(in_progress_block))).is_err() { + return false; + } + } + } + } + } + + true +} + +/// Send final response.completed event to client +/// Returns false if client disconnected +#[allow(clippy::too_many_arguments)] +pub(super) fn send_final_response_event( + handler: &StreamingToolHandler, + tx: &mpsc::UnboundedSender>, + sequence_number: &mut u64, + state: &ToolLoopState, + active_mcp: Option<&Arc>, + original_request: &ResponsesRequest, + previous_response_id: Option<&str>, + server_label: &str, +) -> bool { + let mut final_response = match handler.snapshot_final_response() { + Some(resp) => resp, + None => { + warn!("Final response snapshot unavailable; skipping synthetic completion event"); + return true; + } + }; + + if let Some(original_id) = handler.original_response_id() { + if let Some(obj) = final_response.as_object_mut() { + obj.insert("id".to_string(), Value::String(original_id.to_string())); + } + } + + if let Some(mcp) = active_mcp { + inject_mcp_metadata_streaming(&mut final_response, state, mcp, server_label); + } + + mask_tools_as_mcp(&mut final_response, original_request); + patch_streaming_response_json(&mut final_response, original_request, previous_response_id); + + if let Some(obj) = final_response.as_object_mut() { + obj.insert("status".to_string(), Value::String("completed".to_string())); + } + + let completed_payload = json!({ + "type": event_types::RESPONSE_COMPLETED, + "sequence_number": *sequence_number, + "response": final_response + }); + *sequence_number += 1; + + let completed_event = format!( + "event: {}\ndata: {}\n\n", + event_types::RESPONSE_COMPLETED, + completed_payload + ); + tx.send(Ok(Bytes::from(completed_event))).is_ok() +} + +// ============================================================================ +// Main Streaming Handlers +// ============================================================================ + +/// Simple pass-through streaming without MCP interception +#[allow(clippy::too_many_arguments)] +pub(super) async fn handle_simple_streaming_passthrough( + client: &reqwest::Client, + circuit_breaker: &crate::core::CircuitBreaker, + response_storage: SharedResponseStorage, + conversation_storage: SharedConversationStorage, + conversation_item_storage: SharedConversationItemStorage, + url: String, + headers: Option<&HeaderMap>, + payload: Value, + original_body: &ResponsesRequest, + original_previous_response_id: Option, +) -> Response { + let mut request_builder = client.post(&url).json(&payload); + + if let Some(headers) = headers { + request_builder = apply_request_headers(headers, request_builder, true); + } + + request_builder = request_builder.header("Accept", "text/event-stream"); + + let response = match request_builder.send().await { + Ok(resp) => resp, + Err(err) => { + circuit_breaker.record_failure(); + return ( + StatusCode::BAD_GATEWAY, + format!("Failed to forward request to OpenAI: {}", err), + ) + .into_response(); + } + }; + + let status = response.status(); + let status_code = + StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + if !status.is_success() { + circuit_breaker.record_failure(); + let error_body = match response.text().await { + Ok(body) => body, + Err(err) => format!("Failed to read upstream error body: {}", err), + }; + return (status_code, error_body).into_response(); + } + + circuit_breaker.record_success(); + + let preserved_headers = preserve_response_headers(response.headers()); + let mut upstream_stream = response.bytes_stream(); + + let (tx, rx) = mpsc::unbounded_channel::>(); + + let should_store = original_body.store; + let original_request = original_body.clone(); + let persist_needed = original_request.conversation.is_some(); + let previous_response_id = original_previous_response_id.clone(); + + tokio::spawn(async move { + let mut accumulator = StreamingResponseAccumulator::new(); + let mut upstream_failed = false; + let mut receiver_connected = true; + let mut pending = String::new(); + + while let Some(chunk_result) = upstream_stream.next().await { + match chunk_result { + Ok(chunk) => { + let chunk_text = match std::str::from_utf8(&chunk) { + Ok(text) => Cow::Borrowed(text), + Err(_) => Cow::Owned(String::from_utf8_lossy(&chunk).to_string()), + }; + + pending.push_str(&chunk_text.replace("\r\n", "\n")); + + while let Some(pos) = pending.find("\n\n") { + let raw_block = pending[..pos].to_string(); + pending.drain(..pos + 2); + + if raw_block.trim().is_empty() { + continue; + } + + let block_cow = if let Some(modified) = rewrite_streaming_block( + raw_block.as_str(), + &original_request, + previous_response_id.as_deref(), + ) { + Cow::Owned(modified) + } else { + Cow::Borrowed(raw_block.as_str()) + }; + + if should_store || persist_needed { + accumulator.ingest_block(block_cow.as_ref()); + } + + if receiver_connected { + let chunk_to_send = format!("{}\n\n", block_cow); + if tx.send(Ok(Bytes::from(chunk_to_send))).is_err() { + receiver_connected = false; + } + } + + if !receiver_connected && !should_store { + break; + } + } + + if !receiver_connected && !should_store { + break; + } + } + Err(err) => { + upstream_failed = true; + let io_err = io::Error::other(err); + let _ = tx.send(Err(io_err)); + break; + } + } + } + + if (should_store || persist_needed) && !upstream_failed { + if !pending.trim().is_empty() { + accumulator.ingest_block(&pending); + } + let encountered_error = accumulator.encountered_error().cloned(); + if let Some(mut response_json) = accumulator.into_final_response() { + patch_streaming_response_json( + &mut response_json, + &original_request, + previous_response_id.as_deref(), + ); + + if persist_needed { + if let Err(err) = persist_conversation_items( + conversation_storage.clone(), + conversation_item_storage.clone(), + response_storage.clone(), + &response_json, + &original_request, + ) + .await + { + warn!("Failed to persist conversation items (stream): {}", err); + } + } else if should_store { + // Store response only if no conversation (persist_conversation_items already stores it) + if let Err(err) = + store_response_impl(&response_storage, &response_json, &original_request) + .await + { + warn!("Failed to store streaming response: {}", err); + } + } + } else if let Some(error_payload) = encountered_error { + warn!("Upstream streaming error payload: {}", error_payload); + } else { + warn!("Streaming completed without a final response payload"); + } + } + }); + + let body_stream = UnboundedReceiverStream::new(rx); + let mut response = Response::new(Body::from_stream(body_stream)); + *response.status_mut() = status_code; + + let headers_mut = response.headers_mut(); + for (name, value) in preserved_headers.iter() { + headers_mut.insert(name, value.clone()); + } + + if !headers_mut.contains_key(CONTENT_TYPE) { + headers_mut.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + } + + response +} + +/// Handle streaming WITH MCP tool call interception and execution +#[allow(clippy::too_many_arguments)] +pub(super) async fn handle_streaming_with_tool_interception( + client: &reqwest::Client, + response_storage: SharedResponseStorage, + conversation_storage: SharedConversationStorage, + conversation_item_storage: SharedConversationItemStorage, + url: String, + headers: Option<&HeaderMap>, + mut payload: Value, + original_body: &ResponsesRequest, + original_previous_response_id: Option, + active_mcp: &Arc, +) -> Response { + // Transform MCP tools to function tools in payload + prepare_mcp_payload_for_streaming(&mut payload, active_mcp); + + let (tx, rx) = mpsc::unbounded_channel::>(); + let should_store = original_body.store; + let original_request = original_body.clone(); + let persist_needed = original_request.conversation.is_some(); + let previous_response_id = original_previous_response_id.clone(); + + let client_clone = client.clone(); + let url_clone = url.clone(); + let headers_opt = headers.cloned(); + let payload_clone = payload.clone(); + let active_mcp_clone = Arc::clone(active_mcp); + + // Spawn the streaming loop task + tokio::spawn(async move { + let mut state = ToolLoopState::new(original_request.input.clone()); + let loop_config = McpLoopConfig::default(); + let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize); + let tools_json = payload_clone.get("tools").cloned().unwrap_or(json!([])); + let base_payload = payload_clone.clone(); + let mut current_payload = payload_clone; + let mut mcp_list_tools_sent = false; + let mut is_first_iteration = true; + let mut sequence_number: u64 = 0; + let mut next_output_index: usize = 0; + let mut preserved_response_id: Option = None; + + let server_label = original_request + .tools + .iter() + .find(|t| matches!(t.r#type, ResponseToolType::Mcp)) + .and_then(|t| t.server_label.as_deref()) + .unwrap_or("mcp"); + + loop { + // Make streaming request + let mut request_builder = client_clone.post(&url_clone).json(¤t_payload); + if let Some(ref h) = headers_opt { + request_builder = apply_request_headers(h, request_builder, true); + } + request_builder = request_builder.header("Accept", "text/event-stream"); + + let response = match request_builder.send().await { + Ok(r) => r, + Err(e) => { + let error_event = format!( + "event: error\ndata: {{\"error\": {{\"message\": \"{}\"}}}}\n\n", + e + ); + let _ = tx.send(Ok(Bytes::from(error_event))); + return; + } + }; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + let error_event = format!("event: error\ndata: {{\"error\": {{\"message\": \"Upstream error {}: {}\"}}}}\n\n", status, body); + let _ = tx.send(Ok(Bytes::from(error_event))); + return; + } + + // Stream events and check for tool calls + let mut upstream_stream = response.bytes_stream(); + let mut handler = StreamingToolHandler::with_starting_index(next_output_index); + if let Some(ref id) = preserved_response_id { + handler.original_response_id = Some(id.clone()); + } + let mut pending = String::new(); + let mut tool_calls_detected = false; + let mut seen_in_progress = false; + + while let Some(chunk_result) = upstream_stream.next().await { + match chunk_result { + Ok(chunk) => { + let chunk_text = match std::str::from_utf8(&chunk) { + Ok(text) => Cow::Borrowed(text), + Err(_) => Cow::Owned(String::from_utf8_lossy(&chunk).to_string()), + }; + + pending.push_str(&chunk_text.replace("\r\n", "\n")); + + while let Some(pos) = pending.find("\n\n") { + let raw_block = pending[..pos].to_string(); + pending.drain(..pos + 2); + + if raw_block.trim().is_empty() { + continue; + } + + // Parse event + let (event_name, data) = parse_sse_block(&raw_block); + + if data.is_empty() { + continue; + } + + // Process through handler + let action = handler.process_event(event_name, data.as_ref()); + + match action { + StreamAction::Forward => { + // Skip response.created and response.in_progress on subsequent iterations + let should_skip = if !is_first_iteration { + if let Ok(parsed) = + serde_json::from_str::(data.as_ref()) + { + matches!( + parsed.get("type").and_then(|v| v.as_str()), + Some(event_types::RESPONSE_CREATED) + | Some(event_types::RESPONSE_IN_PROGRESS) + ) + } else { + false + } + } else { + false + }; + + if !should_skip { + // Forward the event + if !forward_streaming_event( + &raw_block, + event_name, + data.as_ref(), + &mut handler, + &tx, + server_label, + &original_request, + previous_response_id.as_deref(), + &mut sequence_number, + ) { + // Client disconnected + return; + } + } + + // After forwarding response.in_progress, send mcp_list_tools events (once) + if !seen_in_progress { + if let Ok(parsed) = + serde_json::from_str::(data.as_ref()) + { + if parsed.get("type").and_then(|v| v.as_str()) + == Some(event_types::RESPONSE_IN_PROGRESS) + { + seen_in_progress = true; + if !mcp_list_tools_sent { + let list_tools_index = + handler.allocate_synthetic_output_index(); + if !send_mcp_list_tools_events( + &tx, + &active_mcp_clone, + server_label, + list_tools_index, + &mut sequence_number, + ) { + // Client disconnected + return; + } + mcp_list_tools_sent = true; + } + } + } + } + } + StreamAction::Buffer => { + // Don't forward, just buffer + } + StreamAction::ExecuteTools => { + if !forward_streaming_event( + &raw_block, + event_name, + data.as_ref(), + &mut handler, + &tx, + server_label, + &original_request, + previous_response_id.as_deref(), + &mut sequence_number, + ) { + // Client disconnected + return; + } + tool_calls_detected = true; + break; // Exit stream processing to execute tools + } + } + } + + if tool_calls_detected { + break; + } + } + Err(e) => { + let error_event = format!("event: error\ndata: {{\"error\": {{\"message\": \"Stream error: {}\"}}}}\n\n", e); + let _ = tx.send(Ok(Bytes::from(error_event))); + return; + } + } + } + + next_output_index = handler.next_output_index(); + if let Some(id) = handler.original_response_id().map(|s| s.to_string()) { + preserved_response_id = Some(id); + } + + // If no tool calls, we're done - stream is complete + if !tool_calls_detected { + if !send_final_response_event( + &handler, + &tx, + &mut sequence_number, + &state, + Some(&active_mcp_clone), + &original_request, + previous_response_id.as_deref(), + server_label, + ) { + return; + } + + let final_response_json = if should_store || persist_needed { + handler.accumulator.into_final_response() + } else { + None + }; + + if let Some(mut response_json) = final_response_json { + if let Some(ref id) = preserved_response_id { + if let Some(obj) = response_json.as_object_mut() { + obj.insert("id".to_string(), Value::String(id.clone())); + } + } + inject_mcp_metadata_streaming( + &mut response_json, + &state, + &active_mcp_clone, + server_label, + ); + + mask_tools_as_mcp(&mut response_json, &original_request); + patch_streaming_response_json( + &mut response_json, + &original_request, + previous_response_id.as_deref(), + ); + + if persist_needed { + if let Err(err) = persist_conversation_items( + conversation_storage.clone(), + conversation_item_storage.clone(), + response_storage.clone(), + &response_json, + &original_request, + ) + .await + { + warn!( + "Failed to persist conversation items (stream + MCP): {}", + err + ); + } + } else if should_store { + // Store response only if no conversation (persist_conversation_items already stores it) + if let Err(err) = store_response_impl( + &response_storage, + &response_json, + &original_request, + ) + .await + { + warn!("Failed to store streaming response: {}", err); + } + } + } + + let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); + return; + } + + // Execute tools + let pending_calls = handler.take_pending_calls(); + + // Check iteration limit + state.iteration += 1; + state.total_calls += pending_calls.len(); + + let effective_limit = match max_tool_calls { + Some(user_max) => user_max.min(loop_config.max_iterations), + None => loop_config.max_iterations, + }; + + if state.total_calls > effective_limit { + warn!( + "Reached tool call limit during streaming: {}", + effective_limit + ); + let error_event = "event: error\ndata: {\"error\": {\"message\": \"Exceeded max_tool_calls limit\"}}\n\n".to_string(); + let _ = tx.send(Ok(Bytes::from(error_event))); + let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); + return; + } + + // Execute all pending tool calls + if !execute_streaming_tool_calls( + pending_calls, + &active_mcp_clone, + &tx, + &mut state, + server_label, + &mut sequence_number, + ) + .await + { + // Client disconnected during tool execution + return; + } + + // Build resume payload + match build_resume_payload( + &base_payload, + &state.conversation_history, + &state.original_input, + &tools_json, + true, // is_streaming = true + ) { + Ok(resume_payload) => { + current_payload = resume_payload; + // Mark that we're no longer on the first iteration + is_first_iteration = false; + // Continue loop to make next streaming request + } + Err(e) => { + let error_event = format!("event: error\ndata: {{\"error\": {{\"message\": \"Failed to build resume payload: {}\"}}}}\n\n", e); + let _ = tx.send(Ok(Bytes::from(error_event))); + let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); + return; + } + } + } + }); + + let body_stream = UnboundedReceiverStream::new(rx); + let mut response = Response::new(Body::from_stream(body_stream)); + *response.status_mut() = StatusCode::OK; + response + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + response +} + +/// Main entry point for handling streaming responses +/// Delegates to simple passthrough or MCP tool interception based on configuration +#[allow(clippy::too_many_arguments)] +pub(super) async fn handle_streaming_response( + client: &reqwest::Client, + circuit_breaker: &crate::core::CircuitBreaker, + mcp_manager: Option<&Arc>, + response_storage: SharedResponseStorage, + conversation_storage: SharedConversationStorage, + conversation_item_storage: SharedConversationItemStorage, + url: String, + headers: Option<&HeaderMap>, + payload: Value, + original_body: &ResponsesRequest, + original_previous_response_id: Option, +) -> Response { + // Check if MCP is active for this request + let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await; + let active_mcp = req_mcp_manager.as_ref().or(mcp_manager); + + // If no MCP is active, use simple pass-through streaming + if active_mcp.is_none() { + return handle_simple_streaming_passthrough( + client, + circuit_breaker, + response_storage, + conversation_storage, + conversation_item_storage, + url, + headers, + payload, + original_body, + original_previous_response_id, + ) + .await; + } + + let active_mcp = active_mcp.unwrap(); + + // MCP is active - transform tools and set up interception + handle_streaming_with_tool_interception( + client, + response_storage, + conversation_storage, + conversation_item_storage, + url, + headers, + payload, + original_body, + original_previous_response_id, + active_mcp, + ) + .await +} diff --git a/sgl-router/src/routers/openai/utils.rs b/sgl-router/src/routers/openai/utils.rs new file mode 100644 index 000000000..21b80d054 --- /dev/null +++ b/sgl-router/src/routers/openai/utils.rs @@ -0,0 +1,100 @@ +//! Utility types and constants for OpenAI router + +use std::collections::HashMap; + +// ============================================================================ +// SSE Event Type Constants +// ============================================================================ + +/// SSE event type constants - single source of truth for event type strings +pub(crate) mod event_types { + // Response lifecycle events + pub const RESPONSE_CREATED: &str = "response.created"; + pub const RESPONSE_IN_PROGRESS: &str = "response.in_progress"; + pub const RESPONSE_COMPLETED: &str = "response.completed"; + + // Output item events + pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added"; + pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done"; + pub const OUTPUT_ITEM_DELTA: &str = "response.output_item.delta"; + + // Function call events + pub const FUNCTION_CALL_ARGUMENTS_DELTA: &str = "response.function_call_arguments.delta"; + pub const FUNCTION_CALL_ARGUMENTS_DONE: &str = "response.function_call_arguments.done"; + + // MCP call events + pub const MCP_CALL_ARGUMENTS_DELTA: &str = "response.mcp_call_arguments.delta"; + pub const MCP_CALL_ARGUMENTS_DONE: &str = "response.mcp_call_arguments.done"; + pub const MCP_CALL_IN_PROGRESS: &str = "response.mcp_call.in_progress"; + pub const MCP_CALL_COMPLETED: &str = "response.mcp_call.completed"; + pub const MCP_LIST_TOOLS_IN_PROGRESS: &str = "response.mcp_list_tools.in_progress"; + pub const MCP_LIST_TOOLS_COMPLETED: &str = "response.mcp_list_tools.completed"; + + // Item types + pub const ITEM_TYPE_FUNCTION_CALL: &str = "function_call"; + pub const ITEM_TYPE_FUNCTION_TOOL_CALL: &str = "function_tool_call"; + pub const ITEM_TYPE_MCP_CALL: &str = "mcp_call"; + pub const ITEM_TYPE_FUNCTION: &str = "function"; + pub const ITEM_TYPE_MCP_LIST_TOOLS: &str = "mcp_list_tools"; +} + +// ============================================================================ +// Stream Action Enum +// ============================================================================ + +/// Action to take based on streaming event processing +#[derive(Debug)] +pub(crate) enum StreamAction { + Forward, // Pass event to client + Buffer, // Accumulate for tool execution + ExecuteTools, // Function call complete, execute now +} + +// ============================================================================ +// Output Index Mapper +// ============================================================================ + +/// Maps upstream output indices to sequential downstream indices +#[derive(Debug, Default)] +pub(crate) struct OutputIndexMapper { + next_index: usize, + // Map upstream output_index -> remapped output_index + assigned: HashMap, +} + +impl OutputIndexMapper { + pub fn with_start(next_index: usize) -> Self { + Self { + next_index, + assigned: HashMap::new(), + } + } + + pub fn ensure_mapping(&mut self, upstream_index: usize) -> usize { + *self.assigned.entry(upstream_index).or_insert_with(|| { + let assigned = self.next_index; + self.next_index += 1; + assigned + }) + } + + pub fn lookup(&self, upstream_index: usize) -> Option { + self.assigned.get(&upstream_index).copied() + } + + pub fn allocate_synthetic(&mut self) -> usize { + let assigned = self.next_index; + self.next_index += 1; + assigned + } + + pub fn next_index(&self) -> usize { + self.next_index + } +} + +// ============================================================================ +// Re-export FunctionCallInProgress from mcp module +// ============================================================================ + +pub(crate) use super::mcp::FunctionCallInProgress; diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs index 1b53ed42f..f65ed5ace 100644 --- a/sgl-router/tests/test_openai_routing.rs +++ b/sgl-router/tests/test_openai_routing.rs @@ -22,7 +22,7 @@ use sglang_router_rs::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput, ResponsesGetParams, ResponsesRequest, UserMessageContent, }, - routers::{openai_router::OpenAIRouter, RouterTrait}, + routers::{openai::OpenAIRouter, RouterTrait}, }; use std::collections::HashMap; use std::sync::{