[router] Refactor OpenAI router: split monolithic file and move location (#11359)

Keyang Ru
2025-10-08 21:46:39 -07:00
committed by GitHub
parent 368fd20622
commit 84768d1017
12 changed files with 4492 additions and 4552 deletions

View File

@@ -3,7 +3,8 @@
use super::grpc::pd_router::GrpcPDRouter;
use super::grpc::router::GrpcRouter;
use super::{
http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router},
http::{pd_router::PDRouter, router::Router},
openai::OpenAIRouter,
RouterTrait,
};
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};

View File

@@ -1,6 +1,5 @@
//! HTTP router implementations
pub mod openai_router;
pub mod pd_router;
pub mod pd_types;
pub mod router;

File diff suppressed because it is too large

View File

@@ -19,12 +19,13 @@ pub mod factory;
pub mod grpc;
pub mod header_utils;
pub mod http;
pub mod openai; // New refactored OpenAI router module
pub mod router_manager;
pub use factory::RouterFactory;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
pub use http::{openai_router, pd_router, pd_types, router};
// Re-export HTTP routers for convenience
pub use http::{pd_router, pd_types, router};
/// Core trait for all router implementations
///

View File

@@ -0,0 +1,574 @@
//! Conversation CRUD operations and persistence
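//!
//! A sketch of the body shape accepted by the create/update handlers (illustrative values;
//! metadata is capped at [`MAX_METADATA_PROPERTIES`] entries):
//!
//! ```ignore
//! let body = serde_json::json!({
//!     "metadata": { "project": "demo", "owner": "alice" } // at most 16 keys, object only
//! });
//! let resp = create_conversation(&conversation_storage, body).await;
//! ```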
use crate::data_connector::{
conversation_items::ListParams, conversation_items::SortOrder, Conversation, ConversationId,
ConversationItemStorage, ConversationStorage, NewConversation, NewConversationItem, ResponseId,
ResponseStorage, SharedConversationItemStorage, SharedConversationStorage,
};
use crate::protocols::spec::{ResponseInput, ResponsesRequest};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use chrono::Utc;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{info, warn};
use super::responses::build_stored_response;
/// Maximum number of properties allowed in conversation metadata
pub(crate) const MAX_METADATA_PROPERTIES: usize = 16;
// ============================================================================
// Conversation CRUD Operations
// ============================================================================
/// Create a new conversation
pub(super) async fn create_conversation(
conversation_storage: &SharedConversationStorage,
body: Value,
) -> Response {
// TODO: move this validation into a dedicated request-validation layer
let metadata = match body.get("metadata") {
Some(Value::Object(map)) => {
if map.len() > MAX_METADATA_PROPERTIES {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!(
"metadata cannot have more than {} properties",
MAX_METADATA_PROPERTIES
)
})),
)
.into_response();
}
Some(map.clone())
}
Some(_) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "metadata must be an object"})),
)
.into_response();
}
None => None,
};
let new_conv = NewConversation { metadata };
match conversation_storage.create_conversation(new_conv).await {
Ok(conversation) => {
info!(conversation_id = %conversation.id.0, "Created conversation");
(StatusCode::OK, Json(conversation_to_json(&conversation))).into_response()
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to create conversation: {}", e)})),
)
.into_response(),
}
}
/// Get a conversation by ID
pub(super) async fn get_conversation(
conversation_storage: &SharedConversationStorage,
conv_id: &str,
) -> Response {
let conversation_id = ConversationId::from(conv_id);
match conversation_storage
.get_conversation(&conversation_id)
.await
{
Ok(Some(conversation)) => {
(StatusCode::OK, Json(conversation_to_json(&conversation))).into_response()
}
Ok(None) => (
StatusCode::NOT_FOUND,
Json(json!({"error": "Conversation not found"})),
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
)
.into_response(),
}
}
/// Update a conversation's metadata
pub(super) async fn update_conversation(
conversation_storage: &SharedConversationStorage,
conv_id: &str,
body: Value,
) -> Response {
let conversation_id = ConversationId::from(conv_id);
let current_meta = match conversation_storage
.get_conversation(&conversation_id)
.await
{
Ok(Some(meta)) => meta,
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "Conversation not found"})),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
)
.into_response();
}
};
#[derive(Debug)]
enum Patch {
Set(String, Value),
Delete(String),
}
let mut patches: Vec<Patch> = Vec::new();
if let Some(metadata_val) = body.get("metadata") {
if let Some(map) = metadata_val.as_object() {
for (k, v) in map {
if v.is_null() {
patches.push(Patch::Delete(k.clone()));
} else {
patches.push(Patch::Set(k.clone(), v.clone()));
}
}
} else {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "metadata must be an object"})),
)
.into_response();
}
}
let mut new_metadata = current_meta.metadata.clone().unwrap_or_default();
for patch in patches {
match patch {
Patch::Set(k, v) => {
new_metadata.insert(k, v);
}
Patch::Delete(k) => {
new_metadata.remove(&k);
}
}
}
if new_metadata.len() > MAX_METADATA_PROPERTIES {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!(
"metadata cannot have more than {} properties",
MAX_METADATA_PROPERTIES
)
})),
)
.into_response();
}
let final_metadata = if new_metadata.is_empty() {
None
} else {
Some(new_metadata)
};
match conversation_storage
.update_conversation(&conversation_id, final_metadata)
.await
{
Ok(Some(conversation)) => {
info!(conversation_id = %conversation_id.0, "Updated conversation");
(StatusCode::OK, Json(conversation_to_json(&conversation))).into_response()
}
Ok(None) => (
StatusCode::NOT_FOUND,
Json(json!({"error": "Conversation not found"})),
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to update conversation: {}", e)})),
)
.into_response(),
}
}
/// Delete a conversation
pub(super) async fn delete_conversation(
conversation_storage: &SharedConversationStorage,
conv_id: &str,
) -> Response {
let conversation_id = ConversationId::from(conv_id);
match conversation_storage
.get_conversation(&conversation_id)
.await
{
Ok(Some(_)) => {}
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "Conversation not found"})),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
)
.into_response();
}
}
match conversation_storage
.delete_conversation(&conversation_id)
.await
{
Ok(_) => {
info!(conversation_id = %conversation_id.0, "Deleted conversation");
(
StatusCode::OK,
Json(json!({
"id": conversation_id.0,
"object": "conversation.deleted",
"deleted": true
})),
)
.into_response()
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to delete conversation: {}", e)})),
)
.into_response(),
}
}
/// List items in a conversation with pagination
pub(super) async fn list_conversation_items(
conversation_storage: &SharedConversationStorage,
item_storage: &SharedConversationItemStorage,
conv_id: &str,
query_params: HashMap<String, String>,
) -> Response {
let conversation_id = ConversationId::from(conv_id);
match conversation_storage
.get_conversation(&conversation_id)
.await
{
Ok(Some(_)) => {}
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "Conversation not found"})),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
)
.into_response();
}
}
let limit: usize = query_params
.get("limit")
.and_then(|s| s.parse().ok())
.unwrap_or(100);
let after = query_params.get("after").map(|s| s.to_string());
// Default to descending order (most recent first)
let order = query_params
.get("order")
.and_then(|s| match s.as_str() {
"asc" => Some(SortOrder::Asc),
"desc" => Some(SortOrder::Desc),
_ => None,
})
.unwrap_or(SortOrder::Desc);
let params = ListParams {
limit,
order,
after,
};
match item_storage.list_items(&conversation_id, params).await {
Ok(items) => {
let item_values: Vec<Value> = items
.iter()
.map(|item| {
let mut obj = serde_json::Map::new();
obj.insert("id".to_string(), json!(item.id.0));
obj.insert("type".to_string(), json!(item.item_type));
obj.insert("created_at".to_string(), json!(item.created_at));
obj.insert("content".to_string(), item.content.clone());
if let Some(status) = &item.status {
obj.insert("status".to_string(), json!(status));
}
Value::Object(obj)
})
.collect();
let has_more = items.len() == limit;
let last_id = items.last().map(|item| item.id.0.clone());
(
StatusCode::OK,
Json(json!({
"object": "list",
"data": item_values,
"has_more": has_more,
"first_id": items.first().map(|item| &item.id.0),
"last_id": last_id,
})),
)
.into_response()
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to list items: {}", e)})),
)
.into_response(),
}
}
// ============================================================================
// Persistence Operations
// ============================================================================
/// Persist conversation items (delegates to persist_items_with_storages)
pub(super) async fn persist_conversation_items(
conversation_storage: Arc<dyn ConversationStorage>,
item_storage: Arc<dyn ConversationItemStorage>,
response_storage: Arc<dyn ResponseStorage>,
response_json: &Value,
original_body: &ResponsesRequest,
) -> Result<(), String> {
persist_items_with_storages(
conversation_storage,
item_storage,
response_storage,
response_json,
original_body,
)
.await
}
/// Helper function to create and link a conversation item (two-step API)
async fn create_and_link_item(
item_storage: &Arc<dyn ConversationItemStorage>,
conv_id: &ConversationId,
mut new_item: NewConversationItem,
) -> Result<(), String> {
// Set default status if not provided
if new_item.status.is_none() {
new_item.status = Some("completed".to_string());
}
// Step 1: Create the item
let created = item_storage
.create_item(new_item)
.await
.map_err(|e| format!("Failed to create item: {}", e))?;
// Step 2: Link it to the conversation
item_storage
.link_item(conv_id, &created.id, Utc::now())
.await
.map_err(|e| format!("Failed to link item: {}", e))?;
info!(
conversation_id = %conv_id.0,
item_id = %created.id.0,
item_type = %created.item_type,
"Persisted conversation item and link"
);
Ok(())
}
/// Persist conversation items with all storages
async fn persist_items_with_storages(
conversation_storage: Arc<dyn ConversationStorage>,
item_storage: Arc<dyn ConversationItemStorage>,
response_storage: Arc<dyn ResponseStorage>,
response_json: &Value,
original_body: &ResponsesRequest,
) -> Result<(), String> {
let conv_id = match &original_body.conversation {
Some(id) => ConversationId::from(id.as_str()),
None => return Ok(()),
};
if conversation_storage
.get_conversation(&conv_id)
.await
.map_err(|e| format!("Failed to get conversation: {}", e))?
.is_none()
{
warn!(conversation_id = %conv_id.0, "Conversation not found, skipping item persistence");
return Ok(());
}
let response_id_str = response_json
.get("id")
.and_then(|v| v.as_str())
.ok_or_else(|| "Response missing id field".to_string())?;
let response_id = ResponseId::from(response_id_str);
let response_id_opt = Some(response_id_str.to_string());
// Persist input items
match &original_body.input {
ResponseInput::Text(text) => {
let new_item = NewConversationItem {
id: None, // Let storage generate ID
response_id: response_id_opt.clone(),
item_type: "message".to_string(),
role: Some("user".to_string()),
content: json!([{ "type": "input_text", "text": text }]),
status: Some("completed".to_string()),
};
create_and_link_item(&item_storage, &conv_id, new_item).await?;
}
ResponseInput::Items(items_array) => {
for input_item in items_array {
match input_item {
crate::protocols::spec::ResponseInputOutputItem::Message {
role,
content,
status,
..
} => {
let content_v = serde_json::to_value(content)
.map_err(|e| format!("Failed to serialize content: {}", e))?;
let new_item = NewConversationItem {
id: None,
response_id: response_id_opt.clone(),
item_type: "message".to_string(),
role: Some(role.clone()),
content: content_v,
status: status.clone(),
};
create_and_link_item(&item_storage, &conv_id, new_item).await?;
}
_ => {
// For other types (FunctionToolCall, etc.), serialize the whole item
let item_val = serde_json::to_value(input_item)
.map_err(|e| format!("Failed to serialize item: {}", e))?;
let new_item = NewConversationItem {
id: None,
response_id: response_id_opt.clone(),
item_type: "unknown".to_string(),
role: None,
content: item_val,
status: Some("completed".to_string()),
};
create_and_link_item(&item_storage, &conv_id, new_item).await?;
}
}
}
}
}
// Persist output items
if let Some(output_arr) = response_json.get("output").and_then(|v| v.as_array()) {
for output_item in output_arr {
if let Some(obj) = output_item.as_object() {
let item_type = obj
.get("type")
.and_then(|v| v.as_str())
.unwrap_or("message");
let role = obj.get("role").and_then(|v| v.as_str()).map(String::from);
let status = obj.get("status").and_then(|v| v.as_str()).map(String::from);
let content = if item_type == "message" {
obj.get("content").cloned().unwrap_or(json!([]))
} else if item_type == "function_call" || item_type == "function_tool_call" {
json!({
"type": "function_call",
"name": obj.get("name"),
"call_id": obj.get("call_id").or_else(|| obj.get("id")),
"arguments": obj.get("arguments")
})
} else if item_type == "function_call_output" {
json!({
"type": "function_call_output",
"call_id": obj.get("call_id"),
"output": obj.get("output")
})
} else {
output_item.clone()
};
let new_item = NewConversationItem {
id: None,
response_id: response_id_opt.clone(),
item_type: item_type.to_string(),
role,
content,
status,
};
create_and_link_item(&item_storage, &conv_id, new_item).await?;
}
}
}
// Store the full response using the shared helper
let mut stored_response = build_stored_response(response_json, original_body);
stored_response.id = response_id;
let final_response_id = stored_response.id.clone();
response_storage
.store_response(stored_response)
.await
.map_err(|e| format!("Failed to store response in conversation: {}", e))?;
info!(conversation_id = %conv_id.0, response_id = %final_response_id.0, "Persisted conversation items and response");
Ok(())
}
// ============================================================================
// Helper Functions
// ============================================================================
/// Convert conversation to JSON response
fn conversation_to_json(conversation: &Conversation) -> Value {
let mut response = json!({
"id": conversation.id.0,
"object": "conversation",
"created_at": conversation.created_at.timestamp()
});
if let Some(metadata) = &conversation.metadata {
if !metadata.is_empty() {
if let Some(obj) = response.as_object_mut() {
obj.insert("metadata".to_string(), Value::Object(metadata.clone()));
}
}
}
response
}

View File

@@ -0,0 +1,967 @@
//! MCP (Model Context Protocol) Integration Module
//!
//! This module contains all MCP-related functionality for the OpenAI router:
//! - Tool loop state management for multi-turn tool calling
//! - MCP tool execution and result handling
//! - Output item builders for MCP-specific response formats
//! - SSE event generation for streaming MCP operations
//! - Payload transformation for MCP tool interception
//! - Metadata injection for MCP operations
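//!
//! A minimal sketch of how the router drives the non-streaming tool loop with this module
//! (crate-internal types, so the example is marked `ignore`; `client`, `url`, `payload`, and
//! `body: &ResponsesRequest` are assumed to be in scope):
//!
//! ```ignore
//! if let Some(mcp) = mcp_manager_from_request_tools(&body.tools).await {
//!     let config = McpLoopConfig::default();                 // safety cap: 10 iterations
//!     prepare_mcp_payload_for_streaming(&mut payload, &mcp); // swap MCP tools for function tools
//!     let response = execute_tool_loop(&client, &url, None, payload, body, &mcp, &config).await?;
//! }
//! ```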
use crate::mcp::McpClientManager;
use crate::protocols::spec::{ResponseInput, ResponseToolType, ResponsesRequest};
use crate::routers::header_utils::apply_request_headers;
use axum::http::HeaderMap;
use bytes::Bytes;
use serde_json::{json, to_value, Value};
use std::{io, sync::Arc};
use tokio::sync::mpsc;
use tracing::{info, warn};
use super::utils::event_types;
// ============================================================================
// Configuration and State Types
// ============================================================================
/// Configuration for MCP tool calling loops
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub(crate) struct McpLoopConfig {
/// Maximum iterations as safety limit (internal only, default: 10)
/// Prevents infinite loops when max_tool_calls is not set
pub max_iterations: usize,
}
impl Default for McpLoopConfig {
fn default() -> Self {
Self { max_iterations: 10 }
}
}
/// State for tracking multi-turn tool calling loop
pub(crate) struct ToolLoopState {
/// Current iteration number (starts at 0, increments with each tool call)
pub iteration: usize,
/// Total number of tool calls executed
pub total_calls: usize,
/// Conversation history (function_call and function_call_output items)
pub conversation_history: Vec<Value>,
/// Original user input (preserved for building resume payloads)
pub original_input: ResponseInput,
}
impl ToolLoopState {
pub fn new(original_input: ResponseInput) -> Self {
Self {
iteration: 0,
total_calls: 0,
conversation_history: Vec::new(),
original_input,
}
}
/// Record a tool call in the loop state
pub fn record_call(
&mut self,
call_id: String,
tool_name: String,
args_json_str: String,
output_str: String,
) {
// Add function_call item to history
let func_item = json!({
"type": event_types::ITEM_TYPE_FUNCTION_CALL,
"call_id": call_id,
"name": tool_name,
"arguments": args_json_str
});
self.conversation_history.push(func_item);
// Add function_call_output item to history
let output_item = json!({
"type": "function_call_output",
"call_id": call_id,
"output": output_str
});
self.conversation_history.push(output_item);
}
}
/// Represents a function call being accumulated across delta events
#[derive(Debug, Clone)]
pub(crate) struct FunctionCallInProgress {
pub call_id: String,
pub name: String,
pub arguments_buffer: String,
pub output_index: usize,
pub last_obfuscation: Option<String>,
pub assigned_output_index: Option<usize>,
}
impl FunctionCallInProgress {
pub fn new(call_id: String, output_index: usize) -> Self {
Self {
call_id,
name: String::new(),
arguments_buffer: String::new(),
output_index,
last_obfuscation: None,
assigned_output_index: None,
}
}
pub fn is_complete(&self) -> bool {
// A tool call is complete if it has a name
!self.name.is_empty()
}
pub fn effective_output_index(&self) -> usize {
self.assigned_output_index.unwrap_or(self.output_index)
}
}
// ============================================================================
// MCP Manager Integration
// ============================================================================
/// Build a request-scoped MCP manager from request tools, if present.
pub(super) async fn mcp_manager_from_request_tools(
tools: &[crate::protocols::spec::ResponseTool],
) -> Option<Arc<McpClientManager>> {
let tool = tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?;
let server_url = tool.server_url.as_ref()?.trim().to_string();
if !(server_url.starts_with("http://") || server_url.starts_with("https://")) {
warn!(
"Ignoring MCP server_url with unsupported scheme: {}",
server_url
);
return None;
}
let name = tool
.server_label
.clone()
.unwrap_or_else(|| "request-mcp".to_string());
let token = tool.authorization.clone();
let transport = if server_url.contains("/sse") {
crate::mcp::McpTransport::Sse {
url: server_url,
token,
}
} else {
crate::mcp::McpTransport::Streamable {
url: server_url,
token,
}
};
let cfg = crate::mcp::McpConfig {
servers: vec![crate::mcp::McpServerConfig { name, transport }],
};
match McpClientManager::new(cfg).await {
Ok(mgr) => Some(Arc::new(mgr)),
Err(err) => {
warn!("Failed to initialize request-scoped MCP manager: {}", err);
None
}
}
}
// ============================================================================
// Tool Execution
// ============================================================================
/// Execute an MCP tool call
pub(super) async fn execute_mcp_call(
mcp_mgr: &Arc<McpClientManager>,
tool_name: &str,
args_json_str: &str,
) -> Result<(String, String), String> {
let args_value: Value =
serde_json::from_str(args_json_str).map_err(|e| format!("parse tool args: {}", e))?;
let args_obj = args_value.as_object().cloned();
let server_name = mcp_mgr
.get_tool(tool_name)
.map(|t| t.server)
.ok_or_else(|| format!("tool not found: {}", tool_name))?;
let result = mcp_mgr
.call_tool(tool_name, args_obj)
.await
.map_err(|e| format!("tool call failed: {}", e))?;
let output_str = serde_json::to_string(&result)
.map_err(|e| format!("Failed to serialize tool result: {}", e))?;
Ok((server_name, output_str))
}
/// Execute detected tool calls and send completion events to client
/// Returns false if client disconnected during execution
pub(super) async fn execute_streaming_tool_calls(
pending_calls: Vec<FunctionCallInProgress>,
active_mcp: &Arc<McpClientManager>,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
state: &mut ToolLoopState,
server_label: &str,
sequence_number: &mut u64,
) -> bool {
// Execute all pending tool calls sequentially (parallel execution is not implemented)
for call in pending_calls {
// Skip if name is empty (invalid call)
if call.name.is_empty() {
warn!(
"Skipping incomplete tool call: name is empty, args_len={}",
call.arguments_buffer.len()
);
continue;
}
info!(
"Executing tool call during streaming: {} ({})",
call.name, call.call_id
);
// Use empty JSON object if arguments_buffer is empty
let args_str = if call.arguments_buffer.is_empty() {
"{}"
} else {
&call.arguments_buffer
};
let call_result = execute_mcp_call(active_mcp, &call.name, args_str).await;
let (output_str, success, error_msg) = match call_result {
Ok((_, output)) => (output, true, None),
Err(err) => {
warn!("Tool execution failed during streaming: {}", err);
(json!({ "error": &err }).to_string(), false, Some(err))
}
};
// Send mcp_call completion event to client
if !send_mcp_call_completion_events_with_error(
tx,
&call,
&output_str,
server_label,
success,
error_msg.as_deref(),
sequence_number,
) {
// Client disconnected, no point continuing tool execution
return false;
}
// Record the call
state.record_call(call.call_id, call.name, call.arguments_buffer, output_str);
}
true
}
// ============================================================================
// Payload Transformation
// ============================================================================
/// Transform payload to replace MCP tools with function tools for streaming
pub(super) fn prepare_mcp_payload_for_streaming(
payload: &mut Value,
active_mcp: &Arc<McpClientManager>,
) {
if let Some(obj) = payload.as_object_mut() {
// Remove any non-function tools from outgoing payload
if let Some(v) = obj.get_mut("tools") {
if let Some(arr) = v.as_array_mut() {
arr.retain(|item| {
item.get("type")
.and_then(|v| v.as_str())
.map(|s| s == event_types::ITEM_TYPE_FUNCTION)
.unwrap_or(false)
});
}
}
// Build function tools for all discovered MCP tools
let mut tools_json = Vec::new();
let tools = active_mcp.list_tools();
for t in tools {
let parameters = t.parameters.clone().unwrap_or(serde_json::json!({
"type": "object",
"properties": {},
"additionalProperties": false
}));
let tool = serde_json::json!({
"type": event_types::ITEM_TYPE_FUNCTION,
"name": t.name,
"description": t.description,
"parameters": parameters
});
tools_json.push(tool);
}
if !tools_json.is_empty() {
obj.insert("tools".to_string(), Value::Array(tools_json));
obj.insert("tool_choice".to_string(), Value::String("auto".to_string()));
}
}
}
/// Build a resume payload with conversation history
pub(super) fn build_resume_payload(
base_payload: &Value,
conversation_history: &[Value],
original_input: &ResponseInput,
tools_json: &Value,
is_streaming: bool,
) -> Result<Value, String> {
// Clone the base payload which already has cleaned fields
let mut payload = base_payload.clone();
let obj = payload
.as_object_mut()
.ok_or_else(|| "payload not an object".to_string())?;
// Build input array: start with original user input
let mut input_array = Vec::new();
// Add original user message
// For structured input, serialize the original input items
match original_input {
ResponseInput::Text(text) => {
let user_item = json!({
"type": "message",
"role": "user",
"content": [{ "type": "input_text", "text": text }]
});
input_array.push(user_item);
}
ResponseInput::Items(items) => {
// Items are already structured ResponseInputOutputItem, convert to JSON
if let Ok(items_value) = to_value(items) {
if let Some(items_arr) = items_value.as_array() {
input_array.extend_from_slice(items_arr);
}
}
}
}
// Add all conversation history (function calls and outputs)
input_array.extend_from_slice(conversation_history);
obj.insert("input".to_string(), Value::Array(input_array));
// Use the transformed tools (function tools, not MCP tools)
if let Some(tools_arr) = tools_json.as_array() {
if !tools_arr.is_empty() {
obj.insert("tools".to_string(), tools_json.clone());
}
}
// Set streaming mode based on caller's context
obj.insert("stream".to_string(), Value::Bool(is_streaming));
obj.insert("store".to_string(), Value::Bool(false));
// Note: SGLang-specific fields were already removed from base_payload
// before it was passed to execute_tool_loop (see route_responses)
Ok(payload)
}
// ============================================================================
// SSE Event Senders
// ============================================================================
/// Send mcp_list_tools events to client at the start of streaming
/// Returns false if client disconnected
pub(super) fn send_mcp_list_tools_events(
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
mcp: &Arc<McpClientManager>,
server_label: &str,
output_index: usize,
sequence_number: &mut u64,
) -> bool {
let tools_item_full = build_mcp_list_tools_item(mcp, server_label);
let item_id = tools_item_full
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("");
// Create empty tools version for the initial added event
let mut tools_item_empty = tools_item_full.clone();
if let Some(obj) = tools_item_empty.as_object_mut() {
obj.insert("tools".to_string(), json!([]));
}
// Event 1: response.output_item.added with empty tools
let event1_payload = json!({
"type": event_types::OUTPUT_ITEM_ADDED,
"sequence_number": *sequence_number,
"output_index": output_index,
"item": tools_item_empty
});
*sequence_number += 1;
let event1 = format!(
"event: {}\ndata: {}\n\n",
event_types::OUTPUT_ITEM_ADDED,
event1_payload
);
if tx.send(Ok(Bytes::from(event1))).is_err() {
return false; // Client disconnected
}
// Event 2: response.mcp_list_tools.in_progress
let event2_payload = json!({
"type": event_types::MCP_LIST_TOOLS_IN_PROGRESS,
"sequence_number": *sequence_number,
"output_index": output_index,
"item_id": item_id
});
*sequence_number += 1;
let event2 = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_LIST_TOOLS_IN_PROGRESS,
event2_payload
);
if tx.send(Ok(Bytes::from(event2))).is_err() {
return false;
}
// Event 3: response.mcp_list_tools.completed
let event3_payload = json!({
"type": event_types::MCP_LIST_TOOLS_COMPLETED,
"sequence_number": *sequence_number,
"output_index": output_index,
"item_id": item_id
});
*sequence_number += 1;
let event3 = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_LIST_TOOLS_COMPLETED,
event3_payload
);
if tx.send(Ok(Bytes::from(event3))).is_err() {
return false;
}
// Event 4: response.output_item.done with full tools list
let event4_payload = json!({
"type": event_types::OUTPUT_ITEM_DONE,
"sequence_number": *sequence_number,
"output_index": output_index,
"item": tools_item_full
});
*sequence_number += 1;
let event4 = format!(
"event: {}\ndata: {}\n\n",
event_types::OUTPUT_ITEM_DONE,
event4_payload
);
tx.send(Ok(Bytes::from(event4))).is_ok()
}
/// Send mcp_call completion events after tool execution
/// Returns false if client disconnected
pub(super) fn send_mcp_call_completion_events_with_error(
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
call: &FunctionCallInProgress,
output: &str,
server_label: &str,
success: bool,
error_msg: Option<&str>,
sequence_number: &mut u64,
) -> bool {
let effective_output_index = call.effective_output_index();
// Build mcp_call item (reuse existing function)
let mcp_call_item = build_mcp_call_item(
&call.name,
&call.arguments_buffer,
output,
server_label,
success,
error_msg,
);
// Get the mcp_call item_id
let item_id = mcp_call_item
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("");
// Event 1: response.mcp_call.completed
let completed_payload = json!({
"type": event_types::MCP_CALL_COMPLETED,
"sequence_number": *sequence_number,
"output_index": effective_output_index,
"item_id": item_id
});
*sequence_number += 1;
let completed_event = format!(
"event: {}\ndata: {}\n\n",
event_types::MCP_CALL_COMPLETED,
completed_payload
);
if tx.send(Ok(Bytes::from(completed_event))).is_err() {
return false;
}
// Event 2: response.output_item.done (with completed mcp_call)
let done_payload = json!({
"type": event_types::OUTPUT_ITEM_DONE,
"sequence_number": *sequence_number,
"output_index": effective_output_index,
"item": mcp_call_item
});
*sequence_number += 1;
let done_event = format!(
"event: {}\ndata: {}\n\n",
event_types::OUTPUT_ITEM_DONE,
done_payload
);
tx.send(Ok(Bytes::from(done_event))).is_ok()
}
// ============================================================================
// Metadata Injection
// ============================================================================
/// Inject MCP metadata into a streaming response
pub(super) fn inject_mcp_metadata_streaming(
response: &mut Value,
state: &ToolLoopState,
mcp: &Arc<McpClientManager>,
server_label: &str,
) {
if let Some(output_array) = response.get_mut("output").and_then(|v| v.as_array_mut()) {
output_array.retain(|item| {
item.get("type").and_then(|t| t.as_str()) != Some(event_types::ITEM_TYPE_MCP_LIST_TOOLS)
});
let list_tools_item = build_mcp_list_tools_item(mcp, server_label);
output_array.insert(0, list_tools_item);
let mcp_call_items =
build_executed_mcp_call_items(&state.conversation_history, server_label);
let mut insert_pos = 1;
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
} else if let Some(obj) = response.as_object_mut() {
let mut output_items = Vec::new();
output_items.push(build_mcp_list_tools_item(mcp, server_label));
output_items.extend(build_executed_mcp_call_items(
&state.conversation_history,
server_label,
));
obj.insert("output".to_string(), Value::Array(output_items));
}
}
// ============================================================================
// Tool Loop Execution
// ============================================================================
/// Execute the tool calling loop
pub(super) async fn execute_tool_loop(
client: &reqwest::Client,
url: &str,
headers: Option<&HeaderMap>,
initial_payload: Value,
original_body: &ResponsesRequest,
active_mcp: &Arc<McpClientManager>,
config: &McpLoopConfig,
) -> Result<Value, String> {
let mut state = ToolLoopState::new(original_body.input.clone());
// Get max_tool_calls from request (None means no user-specified limit)
let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize);
// Keep initial_payload as base template (already has fields cleaned)
let base_payload = initial_payload.clone();
let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([]));
let mut current_payload = initial_payload;
info!(
"Starting tool loop: max_tool_calls={:?}, max_iterations={}",
max_tool_calls, config.max_iterations
);
loop {
// Make request to upstream
let request_builder = client.post(url).json(&current_payload);
let request_builder = if let Some(headers) = headers {
apply_request_headers(headers, request_builder, true)
} else {
request_builder
};
let response = request_builder
.send()
.await
.map_err(|e| format!("upstream request failed: {}", e))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(format!("upstream error {}: {}", status, body));
}
let mut response_json = response
.json::<Value>()
.await
.map_err(|e| format!("parse response: {}", e))?;
// Check for function call
if let Some((call_id, tool_name, args_json_str)) = extract_function_call(&response_json) {
state.iteration += 1;
state.total_calls += 1;
info!(
"Tool loop iteration {}: calling {} (call_id: {})",
state.iteration, tool_name, call_id
);
// Check combined limit: use minimum of user's max_tool_calls (if set) and safety max_iterations
let effective_limit = match max_tool_calls {
Some(user_max) => user_max.min(config.max_iterations),
None => config.max_iterations,
};
if state.total_calls > effective_limit {
if let Some(user_max) = max_tool_calls {
if state.total_calls > user_max {
warn!("Reached user-specified max_tool_calls limit: {}", user_max);
} else {
warn!(
"Reached safety max_iterations limit: {}",
config.max_iterations
);
}
} else {
warn!(
"Reached safety max_iterations limit: {}",
config.max_iterations
);
}
return build_incomplete_response(
response_json,
state,
"max_tool_calls",
active_mcp,
original_body,
);
}
// Execute tool
let call_result = execute_mcp_call(active_mcp, &tool_name, &args_json_str).await;
let output_str = match call_result {
Ok((_, output)) => output,
Err(err) => {
warn!("Tool execution failed: {}", err);
// Return error as output, let model decide how to proceed
json!({ "error": err }).to_string()
}
};
// Record the call
state.record_call(call_id, tool_name, args_json_str, output_str);
// Build resume payload
current_payload = build_resume_payload(
&base_payload,
&state.conversation_history,
&state.original_input,
&tools_json,
false, // is_streaming = false (non-streaming tool loop)
)?;
} else {
// No more tool calls, we're done
info!(
"Tool loop completed: {} iterations, {} total calls",
state.iteration, state.total_calls
);
// Inject MCP output items if we executed any tools
if state.total_calls > 0 {
let server_label = original_body
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
.unwrap_or("mcp");
// Build mcp_list_tools item
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
// Insert at beginning of output array
if let Some(output_array) = response_json
.get_mut("output")
.and_then(|v| v.as_array_mut())
{
output_array.insert(0, list_tools_item);
// Build mcp_call items using helper function
let mcp_call_items =
build_executed_mcp_call_items(&state.conversation_history, server_label);
// Insert mcp_call items after mcp_list_tools using mutable position
let mut insert_pos = 1;
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
}
}
return Ok(response_json);
}
}
}
/// Build an incomplete response when limits are exceeded
pub(super) fn build_incomplete_response(
mut response: Value,
state: ToolLoopState,
reason: &str,
active_mcp: &Arc<McpClientManager>,
original_body: &ResponsesRequest,
) -> Result<Value, String> {
let obj = response
.as_object_mut()
.ok_or_else(|| "response not an object".to_string())?;
// Set status to completed (not failed - partial success)
obj.insert("status".to_string(), Value::String("completed".to_string()));
// Set incomplete_details
obj.insert(
"incomplete_details".to_string(),
json!({ "reason": reason }),
);
// Convert any function_call in output to mcp_call format
if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
let server_label = original_body
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
.unwrap_or("mcp");
// Find any function_call items and convert them to mcp_call (incomplete)
let mut mcp_call_items = Vec::new();
for item in output_array.iter() {
let item_type = item.get("type").and_then(|t| t.as_str());
if item_type == Some(event_types::ITEM_TYPE_FUNCTION_TOOL_CALL)
|| item_type == Some(event_types::ITEM_TYPE_FUNCTION_CALL)
{
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
let args = item
.get("arguments")
.and_then(|v| v.as_str())
.unwrap_or("{}");
// Mark as incomplete - not executed
let mcp_call_item = build_mcp_call_item(
tool_name,
args,
"", // No output - wasn't executed
server_label,
false, // Not successful
Some("Not executed - response stopped due to limit"),
);
mcp_call_items.push(mcp_call_item);
}
}
// Add mcp_list_tools and executed mcp_call items at the beginning
if state.total_calls > 0 || !mcp_call_items.is_empty() {
let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
output_array.insert(0, list_tools_item);
// Add mcp_call items for executed calls using helper
let executed_items =
build_executed_mcp_call_items(&state.conversation_history, server_label);
let mut insert_pos = 1;
for item in executed_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
// Add incomplete mcp_call items
for item in mcp_call_items {
output_array.insert(insert_pos, item);
insert_pos += 1;
}
}
}
// Add warning to metadata
if let Some(metadata_val) = obj.get_mut("metadata") {
if let Some(metadata_obj) = metadata_val.as_object_mut() {
if let Some(mcp_val) = metadata_obj.get_mut("mcp") {
if let Some(mcp_obj) = mcp_val.as_object_mut() {
mcp_obj.insert(
"truncation_warning".to_string(),
Value::String(format!(
"Loop terminated at {} iterations, {} total calls (reason: {})",
state.iteration, state.total_calls, reason
)),
);
}
}
}
}
Ok(response)
}
// ============================================================================
// Output Item Builders
// ============================================================================
/// Generate a unique ID for MCP output items (similar to OpenAI format)
pub(super) fn generate_mcp_id(prefix: &str) -> String {
use rand::RngCore;
let mut rng = rand::rng();
let mut bytes = [0u8; 30];
rng.fill_bytes(&mut bytes);
let hex_string: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
format!("{}_{}", prefix, hex_string)
}
/// Build an mcp_list_tools output item
pub(super) fn build_mcp_list_tools_item(mcp: &Arc<McpClientManager>, server_label: &str) -> Value {
let tools = mcp.list_tools();
let tools_json: Vec<Value> = tools
.iter()
.map(|t| {
json!({
"name": t.name,
"description": t.description,
"input_schema": t.parameters.clone().unwrap_or_else(|| json!({
"type": "object",
"properties": {},
"additionalProperties": false
})),
"annotations": {
"read_only": false
}
})
})
.collect();
json!({
"id": generate_mcp_id("mcpl"),
"type": event_types::ITEM_TYPE_MCP_LIST_TOOLS,
"server_label": server_label,
"tools": tools_json
})
}
/// Build an mcp_call output item
pub(super) fn build_mcp_call_item(
tool_name: &str,
arguments: &str,
output: &str,
server_label: &str,
success: bool,
error: Option<&str>,
) -> Value {
json!({
"id": generate_mcp_id("mcp"),
"type": event_types::ITEM_TYPE_MCP_CALL,
"status": if success { "completed" } else { "failed" },
"approval_request_id": Value::Null,
"arguments": arguments,
"error": error,
"name": tool_name,
"output": output,
"server_label": server_label
})
}
/// Helper function to build mcp_call items from executed tool calls in conversation history
pub(super) fn build_executed_mcp_call_items(
conversation_history: &[Value],
server_label: &str,
) -> Vec<Value> {
let mut mcp_call_items = Vec::new();
for item in conversation_history {
if item.get("type").and_then(|t| t.as_str()) == Some(event_types::ITEM_TYPE_FUNCTION_CALL) {
let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or("");
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
let args = item
.get("arguments")
.and_then(|v| v.as_str())
.unwrap_or("{}");
// Find corresponding output
let output_item = conversation_history.iter().find(|o| {
o.get("type").and_then(|t| t.as_str()) == Some("function_call_output")
&& o.get("call_id").and_then(|c| c.as_str()) == Some(call_id)
});
let output_str = output_item
.and_then(|o| o.get("output").and_then(|v| v.as_str()))
.unwrap_or("{}");
// Check if output contains error by parsing JSON
let is_error = serde_json::from_str::<Value>(output_str)
.map(|v| v.get("error").is_some())
.unwrap_or(false);
let mcp_call_item = build_mcp_call_item(
tool_name,
args,
output_str,
server_label,
!is_error,
if is_error {
Some("Tool execution failed")
} else {
None
},
);
mcp_call_items.push(mcp_call_item);
}
}
mcp_call_items
}
// ============================================================================
// Helper Functions
// ============================================================================
/// Extract function call from a response
pub(super) fn extract_function_call(resp: &Value) -> Option<(String, String, String)> {
let output = resp.get("output")?.as_array()?;
for item in output {
let obj = item.as_object()?;
let t = obj.get("type")?.as_str()?;
if t == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL
|| t == event_types::ITEM_TYPE_FUNCTION_CALL
{
let call_id = obj
.get("call_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| {
obj.get("id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
})?;
let name = obj.get("name")?.as_str()?.to_string();
let arguments = obj.get("arguments")?.as_str()?.to_string();
return Some((call_id, name, arguments));
}
}
None
}

View File

@@ -0,0 +1,18 @@
//! OpenAI-compatible router implementation
//!
//! This module provides OpenAI-compatible API routing with support for:
//! - Streaming and non-streaming responses
//! - MCP (Model Context Protocol) tool calling
//! - Response storage and conversation management
//! - Multi-turn tool execution loops
//! - SSE (Server-Sent Events) streaming
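//!
//! A minimal construction sketch (the shared storage handles are assumed to be built
//! elsewhere; marked `ignore` because it relies on crate-internal types):
//!
//! ```ignore
//! let router = OpenAIRouter::new(
//!     "https://api.openai.com".to_string(),
//!     None,                      // default circuit-breaker settings
//!     response_storage,          // SharedResponseStorage
//!     conversation_storage,      // SharedConversationStorage
//!     conversation_item_storage, // SharedConversationItemStorage
//! )
//! .await?;
//! ```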
mod conversations;
mod mcp;
mod responses;
mod router;
mod streaming;
mod utils;
// Re-export the main router type for external use
pub use router::OpenAIRouter;

View File

@@ -0,0 +1,368 @@
//! Response storage, patching, and extraction utilities
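//!
//! A sketch of the SSE patching done by [`rewrite_streaming_block`] (the event payload and the
//! previous response id are illustrative):
//!
//! ```ignore
//! let block = "event: response.created\ndata: {\"type\":\"response.created\",\"response\":{}}";
//! // Patches `store`, `previous_response_id`, and `conversation` on response.* events;
//! // returns None when the block does not need rewriting.
//! let patched = rewrite_streaming_block(block, &original_body, Some("resp_123"));
//! ```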
use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse};
use crate::protocols::spec::{ResponseInput, ResponseToolType, ResponsesRequest};
use serde_json::{json, Value};
use std::collections::HashMap;
use tracing::{info, warn};
use super::utils::event_types;
// ============================================================================
// Response Storage Operations
// ============================================================================
/// Store a response internally (no-op unless the request sets `store: true`)
pub(super) async fn store_response_internal(
response_storage: &SharedResponseStorage,
response_json: &Value,
original_body: &ResponsesRequest,
) -> Result<(), String> {
if !original_body.store {
return Ok(());
}
match store_response_impl(response_storage, response_json, original_body).await {
Ok(response_id) => {
info!(response_id = %response_id.0, "Stored response locally");
Ok(())
}
Err(e) => Err(e),
}
}
/// Build a StoredResponse from response JSON and original request
pub(super) fn build_stored_response(
response_json: &Value,
original_body: &ResponsesRequest,
) -> StoredResponse {
let input_text = match &original_body.input {
ResponseInput::Text(text) => text.clone(),
ResponseInput::Items(_) => "complex input".to_string(),
};
let output_text = extract_primary_output_text(response_json).unwrap_or_default();
let mut stored_response = StoredResponse::new(input_text, output_text, None);
stored_response.instructions = response_json
.get("instructions")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| original_body.instructions.clone());
stored_response.model = response_json
.get("model")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| original_body.model.clone());
stored_response.user = response_json
.get("user")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| original_body.user.clone());
// Set conversation id from request if provided
if let Some(conv_id) = original_body.conversation.clone() {
stored_response.conversation_id = Some(conv_id);
}
stored_response.metadata = response_json
.get("metadata")
.and_then(|v| v.as_object())
.map(|m| {
m.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect::<HashMap<_, _>>()
})
.unwrap_or_else(|| original_body.metadata.clone().unwrap_or_default());
stored_response.previous_response_id = response_json
.get("previous_response_id")
.and_then(|v| v.as_str())
.map(ResponseId::from)
.or_else(|| {
original_body
.previous_response_id
.as_ref()
.map(|id| ResponseId::from(id.as_str()))
});
if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) {
stored_response.id = ResponseId::from(id_str);
}
stored_response.raw_response = response_json.clone();
stored_response
}
/// Store response implementation (pub(super) so sibling modules can reuse it)
pub(super) async fn store_response_impl(
response_storage: &SharedResponseStorage,
response_json: &Value,
original_body: &ResponsesRequest,
) -> Result<ResponseId, String> {
let stored_response = build_stored_response(response_json, original_body);
response_storage
.store_response(stored_response)
.await
.map_err(|e| format!("Failed to store response: {}", e))
}
// ============================================================================
// Response JSON Patching
// ============================================================================
/// Patch streaming response JSON with metadata from original request
pub(super) fn patch_streaming_response_json(
response_json: &mut Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<&str>,
) {
if let Some(obj) = response_json.as_object_mut() {
if let Some(prev_id) = original_previous_response_id {
let should_insert = obj
.get("previous_response_id")
.map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false))
.unwrap_or(true);
if should_insert {
obj.insert(
"previous_response_id".to_string(),
Value::String(prev_id.to_string()),
);
}
}
if !obj.contains_key("instructions")
|| obj
.get("instructions")
.map(|v| v.is_null())
.unwrap_or(false)
{
if let Some(instructions) = &original_body.instructions {
obj.insert(
"instructions".to_string(),
Value::String(instructions.clone()),
);
}
}
if !obj.contains_key("metadata")
|| obj.get("metadata").map(|v| v.is_null()).unwrap_or(false)
{
if let Some(metadata) = &original_body.metadata {
let metadata_map: serde_json::Map<String, Value> = metadata
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
obj.insert("metadata".to_string(), Value::Object(metadata_map));
}
}
obj.insert("store".to_string(), Value::Bool(original_body.store));
if obj
.get("model")
.and_then(|v| v.as_str())
.map(|s| s.is_empty())
.unwrap_or(true)
{
if let Some(model) = &original_body.model {
obj.insert("model".to_string(), Value::String(model.clone()));
}
}
if obj.get("user").map(|v| v.is_null()).unwrap_or(false) {
if let Some(user) = &original_body.user {
obj.insert("user".to_string(), Value::String(user.clone()));
}
}
// Attach conversation id to the client-facing response if present (final aggregated JSON)
if let Some(conv_id) = original_body.conversation.clone() {
obj.insert("conversation".to_string(), json!({"id": conv_id}));
}
}
}
/// Rewrite streaming SSE block to include metadata from original request
pub(super) fn rewrite_streaming_block(
block: &str,
original_body: &ResponsesRequest,
original_previous_response_id: Option<&str>,
) -> Option<String> {
let trimmed = block.trim();
if trimmed.is_empty() {
return None;
}
let mut data_lines: Vec<String> = Vec::new();
for line in trimmed.lines() {
if line.starts_with("data:") {
data_lines.push(line.trim_start_matches("data:").trim_start().to_string());
}
}
if data_lines.is_empty() {
return None;
}
let payload = data_lines.join("\n");
let mut parsed: Value = match serde_json::from_str(&payload) {
Ok(value) => value,
Err(err) => {
warn!("Failed to parse streaming JSON payload: {}", err);
return None;
}
};
let event_type = parsed
.get("type")
.and_then(|v| v.as_str())
.unwrap_or_default();
let should_patch = matches!(
event_type,
event_types::RESPONSE_CREATED
| event_types::RESPONSE_IN_PROGRESS
| event_types::RESPONSE_COMPLETED
);
if !should_patch {
return None;
}
let mut changed = false;
if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) {
let desired_store = Value::Bool(original_body.store);
if response_obj.get("store") != Some(&desired_store) {
response_obj.insert("store".to_string(), desired_store);
changed = true;
}
if let Some(prev_id) = original_previous_response_id {
let needs_previous = response_obj
.get("previous_response_id")
.map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false))
.unwrap_or(true);
if needs_previous {
response_obj.insert(
"previous_response_id".to_string(),
Value::String(prev_id.to_string()),
);
changed = true;
}
}
// Attach conversation id to the response object of streaming events
if let Some(conv_id) = original_body.conversation.clone() {
response_obj.insert("conversation".to_string(), json!({"id": conv_id}));
changed = true;
}
}
if !changed {
return None;
}
let new_payload = match serde_json::to_string(&parsed) {
Ok(json) => json,
Err(err) => {
warn!("Failed to serialize modified streaming payload: {}", err);
return None;
}
};
let mut rebuilt_lines = Vec::new();
let mut data_written = false;
for line in trimmed.lines() {
if line.starts_with("data:") {
if !data_written {
rebuilt_lines.push(format!("data: {}", new_payload));
data_written = true;
}
} else {
rebuilt_lines.push(line.to_string());
}
}
if !data_written {
rebuilt_lines.push(format!("data: {}", new_payload));
}
Some(rebuilt_lines.join("\n"))
}
/// Mask function tools as MCP tools in response for client
pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
let mcp_tool = original_body
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some());
let Some(t) = mcp_tool else {
return;
};
let mut m = serde_json::Map::new();
m.insert("type".to_string(), Value::String("mcp".to_string()));
if let Some(label) = &t.server_label {
m.insert("server_label".to_string(), Value::String(label.clone()));
}
if let Some(url) = &t.server_url {
m.insert("server_url".to_string(), Value::String(url.clone()));
}
if let Some(desc) = &t.server_description {
m.insert(
"server_description".to_string(),
Value::String(desc.clone()),
);
}
if let Some(req) = &t.require_approval {
m.insert("require_approval".to_string(), Value::String(req.clone()));
}
if let Some(allowed) = &t.allowed_tools {
m.insert(
"allowed_tools".to_string(),
Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()),
);
}
if let Some(obj) = resp.as_object_mut() {
obj.insert("tools".to_string(), Value::Array(vec![Value::Object(m)]));
obj.entry("tool_choice")
.or_insert(Value::String("auto".to_string()));
}
}
// ============================================================================
// Output Text Extraction
// ============================================================================
/// Extract primary output text from response JSON
pub(super) fn extract_primary_output_text(response_json: &Value) -> Option<String> {
if let Some(items) = response_json.get("output").and_then(|v| v.as_array()) {
for item in items {
if let Some(content) = item.get("content").and_then(|v| v.as_array()) {
for part in content {
if part
.get("type")
.and_then(|v| v.as_str())
.map(|t| t == "output_text")
.unwrap_or(false)
{
if let Some(text) = part.get("text").and_then(|v| v.as_str()) {
return Some(text.to_string());
}
}
}
}
}
}
None
}

View File

@@ -0,0 +1,909 @@
//! OpenAI router - main coordinator that delegates to specialized modules
use crate::config::CircuitBreakerConfig;
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
use crate::data_connector::{
conversation_items::ListParams, conversation_items::SortOrder, ConversationId, ResponseId,
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
};
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams,
ResponsesRequest,
};
use crate::routers::header_utils::apply_request_headers;
use axum::{
body::Body,
extract::Request,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
Json,
};
use futures_util::StreamExt;
use serde_json::{json, to_value, Value};
use std::{
any::Any,
sync::{atomic::AtomicBool, Arc},
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info, warn};
// Import from sibling modules
use super::conversations::{
create_conversation, delete_conversation, get_conversation, list_conversation_items,
persist_conversation_items, update_conversation,
};
use super::mcp::{
execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
McpLoopConfig,
};
use super::responses::{mask_tools_as_mcp, patch_streaming_response_json, store_response_internal};
use super::streaming::handle_streaming_response;
// ============================================================================
// OpenAIRouter Struct
// ============================================================================
/// Router for OpenAI backend
pub struct OpenAIRouter {
/// HTTP client for upstream OpenAI-compatible API
client: reqwest::Client,
/// Base URL for identification (no trailing slash)
base_url: String,
/// Circuit breaker
circuit_breaker: CircuitBreaker,
/// Health status
healthy: AtomicBool,
/// Response storage for managing conversation history
response_storage: SharedResponseStorage,
/// Conversation storage backend
conversation_storage: SharedConversationStorage,
/// Conversation item storage backend
conversation_item_storage: SharedConversationItemStorage,
/// Optional MCP manager (enabled via config presence)
mcp_manager: Option<Arc<crate::mcp::McpClientManager>>,
}
impl std::fmt::Debug for OpenAIRouter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OpenAIRouter")
.field("base_url", &self.base_url)
.field("healthy", &self.healthy)
.finish()
}
}
impl OpenAIRouter {
/// Maximum number of conversation items to attach as input when a conversation is provided
const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100;
/// Create a new OpenAI router
pub async fn new(
base_url: String,
circuit_breaker_config: Option<CircuitBreakerConfig>,
response_storage: SharedResponseStorage,
conversation_storage: SharedConversationStorage,
conversation_item_storage: SharedConversationItemStorage,
) -> Result<Self, String> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
let base_url = base_url.trim_end_matches('/').to_string();
// Convert circuit breaker config
let core_cb_config = circuit_breaker_config
.map(|cb| CoreCircuitBreakerConfig {
failure_threshold: cb.failure_threshold,
success_threshold: cb.success_threshold,
timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
})
.unwrap_or_default();
let circuit_breaker = CircuitBreaker::with_config(core_cb_config);
// Optional MCP manager, activated when SGLANG_MCP_CONFIG points at a config file
let mcp_manager = match std::env::var("SGLANG_MCP_CONFIG").ok() {
Some(path) if !path.trim().is_empty() => {
match crate::mcp::McpConfig::from_file(&path).await {
Ok(cfg) => match crate::mcp::McpClientManager::new(cfg).await {
Ok(mgr) => Some(Arc::new(mgr)),
Err(err) => {
warn!("Failed to initialize MCP manager: {}", err);
None
}
},
Err(err) => {
warn!("Failed to load MCP config from '{}': {}", path, err);
None
}
}
}
_ => None,
};
Ok(Self {
client,
base_url,
circuit_breaker,
healthy: AtomicBool::new(true),
response_storage,
conversation_storage,
conversation_item_storage,
mcp_manager,
})
}
/// Handle non-streaming response with optional MCP tool loop
async fn handle_non_streaming_response(
&self,
url: String,
headers: Option<&HeaderMap>,
mut payload: Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
) -> Response {
// Check if MCP is active for this request
let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());
let mut response_json: Value;
// If MCP is active, execute tool loop
if let Some(mcp) = active_mcp {
let config = McpLoopConfig::default();
// Transform MCP tools to function tools
prepare_mcp_payload_for_streaming(&mut payload, mcp);
match execute_tool_loop(
&self.client,
&url,
headers,
payload,
original_body,
mcp,
&config,
)
.await
{
Ok(resp) => response_json = resp,
Err(err) => {
self.circuit_breaker.record_failure();
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": {"message": err}})),
)
.into_response();
}
}
} else {
// No MCP - simple request
let mut request_builder = self.client.post(&url).json(&payload);
if let Some(h) = headers {
request_builder = apply_request_headers(h, request_builder, true);
}
let response = match request_builder.send().await {
Ok(r) => r,
Err(e) => {
self.circuit_breaker.record_failure();
return (
StatusCode::BAD_GATEWAY,
format!("Failed to forward request to OpenAI: {}", e),
)
.into_response();
}
};
if !response.status().is_success() {
self.circuit_breaker.record_failure();
let status = StatusCode::from_u16(response.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let body = response.text().await.unwrap_or_default();
return (status, body).into_response();
}
response_json = match response.json::<Value>().await {
Ok(r) => r,
Err(e) => {
self.circuit_breaker.record_failure();
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to parse upstream response: {}", e),
)
.into_response();
}
};
self.circuit_breaker.record_success();
}
// Mask function tools back as MCP tools and patch the response with the original request metadata
mask_tools_as_mcp(&mut response_json, original_body);
patch_streaming_response_json(
&mut response_json,
original_body,
original_previous_response_id.as_deref(),
);
// Persist conversation items if conversation is provided
if original_body.conversation.is_some() {
if let Err(err) = persist_conversation_items(
self.conversation_storage.clone(),
self.conversation_item_storage.clone(),
self.response_storage.clone(),
&response_json,
original_body,
)
.await
{
warn!("Failed to persist conversation items: {}", err);
}
} else {
// No conversation: store the response here (the conversation branch above already stores it via persist_conversation_items)
if let Err(err) =
store_response_internal(&self.response_storage, &response_json, original_body).await
{
warn!("Failed to store response: {}", err);
}
}
(StatusCode::OK, Json(response_json)).into_response()
}
}
// ============================================================================
// RouterTrait Implementation
// ============================================================================
#[async_trait::async_trait]
impl crate::routers::RouterTrait for OpenAIRouter {
fn as_any(&self) -> &dyn Any {
self
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
// Simple upstream probe: GET {base}/v1/models without auth
let url = format!("{}/v1/models", self.base_url);
match self
.client
.get(&url)
.timeout(std::time::Duration::from_secs(2))
.send()
.await
{
Ok(resp) => {
let code = resp.status();
// Treat success and auth-required as healthy (endpoint reachable)
if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
(StatusCode::OK, "OK").into_response()
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
format!("Upstream status: {}", code),
)
.into_response()
}
}
Err(e) => (
StatusCode::SERVICE_UNAVAILABLE,
format!("Upstream error: {}", e),
)
.into_response(),
}
}
async fn get_server_info(&self, _req: Request<Body>) -> Response {
let info = json!({
"router_type": "openai",
"workers": 1,
"base_url": &self.base_url
});
(StatusCode::OK, info.to_string()).into_response()
}
async fn get_models(&self, req: Request<Body>) -> Response {
// Proxy to upstream /v1/models; forward Authorization header if provided
let headers = req.headers();
let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));
if let Some(auth) = headers
.get("authorization")
.or_else(|| headers.get("Authorization"))
{
upstream = upstream.header("Authorization", auth);
}
match upstream.send().await {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let content_type = res.headers().get(CONTENT_TYPE).cloned();
match res.bytes().await {
Ok(body) => {
let mut response = Response::new(Body::from(body));
*response.status_mut() = status;
if let Some(ct) = content_type {
response.headers_mut().insert(CONTENT_TYPE, ct);
}
response
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read upstream response: {}", e),
)
.into_response(),
}
}
Err(e) => (
StatusCode::BAD_GATEWAY,
format!("Failed to contact upstream: {}", e),
)
.into_response(),
}
}
async fn get_model_info(&self, _req: Request<Body>) -> Response {
// Not directly supported without model param; return 501
(
StatusCode::NOT_IMPLEMENTED,
"get_model_info not implemented for OpenAI router",
)
.into_response()
}
async fn route_generate(
&self,
_headers: Option<&HeaderMap>,
_body: &GenerateRequest,
_model_id: Option<&str>,
) -> Response {
// Generate endpoint is SGLang-specific, not supported for OpenAI backend
(
StatusCode::NOT_IMPLEMENTED,
"Generate endpoint not supported for OpenAI backend",
)
.into_response()
}
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
_model_id: Option<&str>,
) -> Response {
if !self.circuit_breaker.can_execute() {
return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
}
// Serialize request body, removing SGLang-only fields
let mut payload = match to_value(body) {
Ok(v) => v,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("Failed to serialize request: {}", e),
)
.into_response();
}
};
if let Some(obj) = payload.as_object_mut() {
for key in [
"top_k",
"min_p",
"min_tokens",
"regex",
"ebnf",
"stop_token_ids",
"no_stop_trim",
"ignore_eos",
"continue_final_message",
"skip_special_tokens",
"lora_path",
"session_params",
"separate_reasoning",
"stream_reasoning",
"chat_template_kwargs",
"return_hidden_states",
"repetition_penalty",
"sampling_seed",
] {
obj.remove(key);
}
}
let url = format!("{}/v1/chat/completions", self.base_url);
let mut req = self.client.post(&url).json(&payload);
// Forward Authorization header if provided
if let Some(h) = headers {
if let Some(auth) = h.get("authorization").or_else(|| h.get("Authorization")) {
req = req.header("Authorization", auth);
}
}
// Accept SSE when stream=true
if body.stream {
req = req.header("Accept", "text/event-stream");
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
self.circuit_breaker.record_failure();
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("Failed to contact upstream: {}", e),
)
.into_response();
}
};
let status = StatusCode::from_u16(resp.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
if !body.stream {
// Capture Content-Type before consuming response body
let content_type = resp.headers().get(CONTENT_TYPE).cloned();
match resp.bytes().await {
Ok(body) => {
self.circuit_breaker.record_success();
let mut response = Response::new(Body::from(body));
*response.status_mut() = status;
if let Some(ct) = content_type {
response.headers_mut().insert(CONTENT_TYPE, ct);
}
response
}
Err(e) => {
self.circuit_breaker.record_failure();
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response: {}", e),
)
.into_response()
}
}
} else {
// Stream SSE bytes to client
let stream = resp.bytes_stream();
let (tx, rx) = mpsc::unbounded_channel();
tokio::spawn(async move {
let mut s = stream;
while let Some(chunk) = s.next().await {
match chunk {
Ok(bytes) => {
if tx.send(Ok(bytes)).is_err() {
break;
}
}
Err(e) => {
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});
let mut response = Response::new(Body::from_stream(UnboundedReceiverStream::new(rx)));
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
}
}
async fn route_completion(
&self,
_headers: Option<&HeaderMap>,
_body: &CompletionRequest,
_model_id: Option<&str>,
) -> Response {
// Completion endpoint not implemented for OpenAI backend
(
StatusCode::NOT_IMPLEMENTED,
"Completion endpoint not implemented for OpenAI backend",
)
.into_response()
}
async fn route_responses(
&self,
headers: Option<&HeaderMap>,
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response {
let url = format!("{}/v1/responses", self.base_url);
info!(
requested_store = body.store,
is_streaming = body.stream,
"openai_responses_request"
);
// Validate mutually exclusive params: previous_response_id and conversation
// TODO: move this validation to the right place; we also need a proper error-message module
if body.previous_response_id.is_some() && body.conversation.is_some() {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": "Mutually exclusive parameters. Ensure you are only providing one of: 'previous_response_id' or 'conversation'.",
"type": "invalid_request_error",
"param": Value::Null,
"code": "mutually_exclusive_parameters"
}
})),
)
.into_response();
}
// Clone the body and override model if needed
let mut request_body = body.clone();
if let Some(model) = model_id {
request_body.model = Some(model.to_string());
}
// Do not forward conversation field upstream; retain for local persistence only
request_body.conversation = None;
// Store the original previous_response_id for the response
let original_previous_response_id = request_body.previous_response_id.clone();
// Handle previous_response_id by loading prior context
let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None;
if let Some(prev_id_str) = request_body.previous_response_id.clone() {
let prev_id = ResponseId::from(prev_id_str.as_str());
match self
.response_storage
.get_response_chain(&prev_id, None)
.await
{
Ok(chain) => {
let mut items = Vec::new();
for stored in chain.responses.iter() {
// Convert input to conversation item
items.push(ResponseInputOutputItem::Message {
id: format!("msg_u_{}", stored.id.0.trim_start_matches("resp_")),
role: "user".to_string(),
content: vec![ResponseContentPart::InputText {
text: stored.input.clone(),
}],
status: Some("completed".to_string()),
});
// Convert output to conversation items directly from stored response
if let Some(output_arr) =
stored.raw_response.get("output").and_then(|v| v.as_array())
{
for item in output_arr {
if let Ok(output_item) =
serde_json::from_value::<ResponseInputOutputItem>(item.clone())
{
items.push(output_item);
}
}
}
}
conversation_items = Some(items);
request_body.previous_response_id = None;
}
Err(e) => {
warn!(
"Failed to load previous response chain for {}: {}",
prev_id_str, e
);
}
}
}
// Handle conversation by loading history
if let Some(conv_id_str) = body.conversation.clone() {
let conv_id = ConversationId::from(conv_id_str.as_str());
// Verify conversation exists
if let Ok(None) = self.conversation_storage.get_conversation(&conv_id).await {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "Conversation not found"})),
)
.into_response();
}
// Load conversation history (ascending order for chronological context)
let params = ListParams {
limit: Self::MAX_CONVERSATION_HISTORY_ITEMS,
order: SortOrder::Asc,
after: None,
};
match self
.conversation_item_storage
.list_items(&conv_id, params)
.await
{
Ok(stored_items) => {
let mut items: Vec<ResponseInputOutputItem> = Vec::new();
for item in stored_items.into_iter() {
// Only use message items for conversation context
// Skip non-message items (reasoning, function calls, etc.)
if item.item_type == "message" {
if let Ok(content_parts) =
serde_json::from_value::<Vec<ResponseContentPart>>(
item.content.clone(),
)
{
items.push(ResponseInputOutputItem::Message {
id: item.id.0.clone(),
role: item.role.clone().unwrap_or_else(|| "user".to_string()),
content: content_parts,
status: item.status.clone(),
});
}
}
}
// Append current request
match &request_body.input {
ResponseInput::Text(text) => {
items.push(ResponseInputOutputItem::Message {
id: format!("msg_u_{}", conv_id.0),
role: "user".to_string(),
content: vec![ResponseContentPart::InputText {
text: text.clone(),
}],
status: Some("completed".to_string()),
});
}
ResponseInput::Items(current_items) => {
items.extend_from_slice(current_items);
}
}
request_body.input = ResponseInput::Items(items);
}
Err(e) => {
warn!("Failed to load conversation history: {}", e);
}
}
}
// If we have conversation_items from previous_response_id, use them
if let Some(mut items) = conversation_items {
// Append current request
match &request_body.input {
ResponseInput::Text(text) => {
items.push(ResponseInputOutputItem::Message {
id: format!(
"msg_u_{}",
original_previous_response_id
.as_ref()
.unwrap_or(&"new".to_string())
),
role: "user".to_string(),
content: vec![ResponseContentPart::InputText { text: text.clone() }],
status: Some("completed".to_string()),
});
}
ResponseInput::Items(current_items) => {
items.extend_from_slice(current_items);
}
}
request_body.input = ResponseInput::Items(items);
}
// Always set store=false for upstream (we store internally)
request_body.store = false;
// Convert to JSON and strip SGLang-specific fields
let mut payload = match to_value(&request_body) {
Ok(v) => v,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("Failed to serialize request: {}", e),
)
.into_response();
}
};
// Remove SGLang-specific fields
if let Some(obj) = payload.as_object_mut() {
for key in [
"request_id",
"priority",
"top_k",
"frequency_penalty",
"presence_penalty",
"min_p",
"min_tokens",
"regex",
"ebnf",
"stop_token_ids",
"no_stop_trim",
"ignore_eos",
"continue_final_message",
"skip_special_tokens",
"lora_path",
"session_params",
"separate_reasoning",
"stream_reasoning",
"chat_template_kwargs",
"return_hidden_states",
"repetition_penalty",
"sampling_seed",
] {
obj.remove(key);
}
}
// Delegate to streaming or non-streaming handler
if body.stream {
handle_streaming_response(
&self.client,
&self.circuit_breaker,
self.mcp_manager.as_ref(),
self.response_storage.clone(),
self.conversation_storage.clone(),
self.conversation_item_storage.clone(),
url,
headers,
payload,
body,
original_previous_response_id,
)
.await
} else {
self.handle_non_streaming_response(
url,
headers,
payload,
body,
original_previous_response_id,
)
.await
}
}
async fn get_response(
&self,
_headers: Option<&HeaderMap>,
response_id: &str,
_params: &ResponsesGetParams,
) -> Response {
let id = ResponseId::from(response_id);
match self.response_storage.get_response(&id).await {
Ok(Some(stored)) => {
let mut response_json = stored.raw_response;
if let Some(obj) = response_json.as_object_mut() {
obj.insert("id".to_string(), json!(id.0));
}
(StatusCode::OK, Json(response_json)).into_response()
}
Ok(None) => (
StatusCode::NOT_FOUND,
Json(json!({"error": "Response not found"})),
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get response: {}", e)})),
)
.into_response(),
}
}
async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
// Forward cancellation to upstream
let url = format!("{}/v1/responses/{}/cancel", self.base_url, response_id);
let mut req = self.client.post(&url);
if let Some(h) = headers {
req = apply_request_headers(h, req, false);
}
match req.send().await {
Ok(resp) => {
let status = StatusCode::from_u16(resp.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
match resp.text().await {
Ok(body) => (status, body).into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response: {}", e),
)
.into_response(),
}
}
Err(e) => (
StatusCode::BAD_GATEWAY,
format!("Failed to contact upstream: {}", e),
)
.into_response(),
}
}
async fn route_embeddings(
&self,
_headers: Option<&HeaderMap>,
_body: &EmbeddingRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED, "Embeddings not supported").into_response()
}
async fn route_rerank(
&self,
_headers: Option<&HeaderMap>,
_body: &RerankRequest,
_model_id: Option<&str>,
) -> Response {
(StatusCode::NOT_IMPLEMENTED, "Rerank not supported").into_response()
}
async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response {
create_conversation(&self.conversation_storage, body.clone()).await
}
async fn get_conversation(
&self,
_headers: Option<&HeaderMap>,
conversation_id: &str,
) -> Response {
get_conversation(&self.conversation_storage, conversation_id).await
}
async fn update_conversation(
&self,
_headers: Option<&HeaderMap>,
conversation_id: &str,
body: &Value,
) -> Response {
update_conversation(&self.conversation_storage, conversation_id, body.clone()).await
}
async fn delete_conversation(
&self,
_headers: Option<&HeaderMap>,
conversation_id: &str,
) -> Response {
delete_conversation(&self.conversation_storage, conversation_id).await
}
fn router_type(&self) -> &'static str {
"openai"
}
async fn list_conversation_items(
&self,
_headers: Option<&HeaderMap>,
conversation_id: &str,
limit: Option<usize>,
order: Option<String>,
after: Option<String>,
) -> Response {
let mut query_params = std::collections::HashMap::new();
query_params.insert("limit".to_string(), limit.unwrap_or(100).to_string());
if let Some(after_val) = after {
if !after_val.is_empty() {
query_params.insert("after".to_string(), after_val);
}
}
if let Some(order_val) = order {
query_params.insert("order".to_string(), order_val);
}
list_conversation_items(
&self.conversation_storage,
&self.conversation_item_storage,
conversation_id,
query_params,
)
.await
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,100 @@
//! Utility types and constants for OpenAI router
use std::collections::HashMap;
// ============================================================================
// SSE Event Type Constants
// ============================================================================
/// SSE event type constants - single source of truth for event type strings
pub(crate) mod event_types {
// Response lifecycle events
pub const RESPONSE_CREATED: &str = "response.created";
pub const RESPONSE_IN_PROGRESS: &str = "response.in_progress";
pub const RESPONSE_COMPLETED: &str = "response.completed";
// Output item events
pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added";
pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done";
pub const OUTPUT_ITEM_DELTA: &str = "response.output_item.delta";
// Function call events
pub const FUNCTION_CALL_ARGUMENTS_DELTA: &str = "response.function_call_arguments.delta";
pub const FUNCTION_CALL_ARGUMENTS_DONE: &str = "response.function_call_arguments.done";
// MCP call events
pub const MCP_CALL_ARGUMENTS_DELTA: &str = "response.mcp_call_arguments.delta";
pub const MCP_CALL_ARGUMENTS_DONE: &str = "response.mcp_call_arguments.done";
pub const MCP_CALL_IN_PROGRESS: &str = "response.mcp_call.in_progress";
pub const MCP_CALL_COMPLETED: &str = "response.mcp_call.completed";
pub const MCP_LIST_TOOLS_IN_PROGRESS: &str = "response.mcp_list_tools.in_progress";
pub const MCP_LIST_TOOLS_COMPLETED: &str = "response.mcp_list_tools.completed";
// Item types
pub const ITEM_TYPE_FUNCTION_CALL: &str = "function_call";
pub const ITEM_TYPE_FUNCTION_TOOL_CALL: &str = "function_tool_call";
pub const ITEM_TYPE_MCP_CALL: &str = "mcp_call";
pub const ITEM_TYPE_FUNCTION: &str = "function";
pub const ITEM_TYPE_MCP_LIST_TOOLS: &str = "mcp_list_tools";
}
// ============================================================================
// Stream Action Enum
// ============================================================================
/// Action to take based on streaming event processing
#[derive(Debug)]
pub(crate) enum StreamAction {
Forward, // Pass event to client
Buffer, // Accumulate for tool execution
ExecuteTools, // Function call complete, execute now
}
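// Illustrative sketch, not part of the original diff: one plausible way the event
// type constants above could drive a StreamAction decision. The real routing logic
// lives in the streaming handler; this only shows the intended relationship.
#[allow(dead_code)]
pub(crate) fn classify_event_sketch(event_type: &str) -> StreamAction {
    match event_type {
        // Function-call argument deltas are buffered until the call is complete
        event_types::FUNCTION_CALL_ARGUMENTS_DELTA => StreamAction::Buffer,
        // Once the arguments are done, the buffered call can be executed locally
        event_types::FUNCTION_CALL_ARGUMENTS_DONE => StreamAction::ExecuteTools,
        // Everything else is forwarded to the client unchanged
        _ => StreamAction::Forward,
    }
}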
// ============================================================================
// Output Index Mapper
// ============================================================================
/// Maps upstream output indices to sequential downstream indices
#[derive(Debug, Default)]
pub(crate) struct OutputIndexMapper {
next_index: usize,
// Map upstream output_index -> remapped output_index
assigned: HashMap<usize, usize>,
}
impl OutputIndexMapper {
pub fn with_start(next_index: usize) -> Self {
Self {
next_index,
assigned: HashMap::new(),
}
}
pub fn ensure_mapping(&mut self, upstream_index: usize) -> usize {
*self.assigned.entry(upstream_index).or_insert_with(|| {
let assigned = self.next_index;
self.next_index += 1;
assigned
})
}
pub fn lookup(&self, upstream_index: usize) -> Option<usize> {
self.assigned.get(&upstream_index).copied()
}
pub fn allocate_synthetic(&mut self) -> usize {
let assigned = self.next_index;
self.next_index += 1;
assigned
}
pub fn next_index(&self) -> usize {
self.next_index
}
}
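// Illustrative behavior sketch, not part of the original diff: upstream output
// indices are remapped onto a dense, monotonically increasing downstream sequence,
// with synthetic items (e.g. injected mcp_list_tools entries) taking their own slots.
#[cfg(test)]
mod output_index_mapper_sketch {
    use super::OutputIndexMapper;

    #[test]
    fn remaps_upstream_indices_sequentially() {
        let mut mapper = OutputIndexMapper::with_start(2);
        // First sighting of upstream index 7 claims the next free downstream slot
        assert_eq!(mapper.ensure_mapping(7), 2);
        // Repeated sightings of the same upstream index are stable
        assert_eq!(mapper.ensure_mapping(7), 2);
        assert_eq!(mapper.lookup(7), Some(2));
        // A synthetic item takes its own downstream slot
        assert_eq!(mapper.allocate_synthetic(), 3);
        assert_eq!(mapper.next_index(), 4);
        // Upstream indices that were never seen have no mapping
        assert_eq!(mapper.lookup(0), None);
    }
}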
// ============================================================================
// Re-export FunctionCallInProgress from mcp module
// ============================================================================
pub(crate) use super::mcp::FunctionCallInProgress;

View File

@@ -22,7 +22,7 @@ use sglang_router_rs::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
ResponsesGetParams, ResponsesRequest, UserMessageContent,
},
routers::{openai_router::OpenAIRouter, RouterTrait},
routers::{openai::OpenAIRouter, RouterTrait},
};
use std::collections::HashMap;
use std::sync::{