[router] Refactor OpenAI router: split monolithic file and move location (#11359)
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
use super::grpc::pd_router::GrpcPDRouter;
|
||||
use super::grpc::router::GrpcRouter;
|
||||
use super::{
|
||||
http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router},
|
||||
http::{pd_router::PDRouter, router::Router},
|
||||
openai::OpenAIRouter,
|
||||
RouterTrait,
|
||||
};
|
||||
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
//! HTTP router implementations
|
||||
|
||||
pub mod openai_router;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
pub mod router;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,12 +19,13 @@ pub mod factory;
|
||||
pub mod grpc;
|
||||
pub mod header_utils;
|
||||
pub mod http;
|
||||
pub mod openai; // New refactored OpenAI router module
|
||||
pub mod router_manager;
|
||||
|
||||
pub use factory::RouterFactory;
|
||||
|
||||
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
|
||||
pub use http::{openai_router, pd_router, pd_types, router};
|
||||
// Re-export HTTP routers for convenience
|
||||
pub use http::{pd_router, pd_types, router};
|
||||
|
||||
/// Core trait for all router implementations
|
||||
///
|
||||
|
||||
574
sgl-router/src/routers/openai/conversations.rs
Normal file
574
sgl-router/src/routers/openai/conversations.rs
Normal file
@@ -0,0 +1,574 @@
|
||||
//! Conversation CRUD operations and persistence
|
||||
|
||||
use crate::data_connector::{
|
||||
conversation_items::ListParams, conversation_items::SortOrder, Conversation, ConversationId,
|
||||
ConversationItemStorage, ConversationStorage, NewConversation, NewConversationItem, ResponseId,
|
||||
ResponseStorage, SharedConversationItemStorage, SharedConversationStorage,
|
||||
};
|
||||
use crate::protocols::spec::{ResponseInput, ResponsesRequest};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use chrono::Utc;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::responses::build_stored_response;
|
||||
|
||||
/// Maximum number of properties allowed in conversation metadata
|
||||
pub(crate) const MAX_METADATA_PROPERTIES: usize = 16;
|
||||
|
||||
// ============================================================================
|
||||
// Conversation CRUD Operations
|
||||
// ============================================================================
|
||||
|
||||
/// Create a new conversation
|
||||
pub(super) async fn create_conversation(
|
||||
conversation_storage: &SharedConversationStorage,
|
||||
body: Value,
|
||||
) -> Response {
|
||||
// TODO: The validation should be done in the right place
|
||||
let metadata = match body.get("metadata") {
|
||||
Some(Value::Object(map)) => {
|
||||
if map.len() > MAX_METADATA_PROPERTIES {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"metadata cannot have more than {} properties",
|
||||
MAX_METADATA_PROPERTIES
|
||||
)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Some(map.clone())
|
||||
}
|
||||
Some(_) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "metadata must be an object"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let new_conv = NewConversation { metadata };
|
||||
|
||||
match conversation_storage.create_conversation(new_conv).await {
|
||||
Ok(conversation) => {
|
||||
info!(conversation_id = %conversation.id.0, "Created conversation");
|
||||
(StatusCode::OK, Json(conversation_to_json(&conversation))).into_response()
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to create conversation: {}", e)})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a conversation by ID
|
||||
pub(super) async fn get_conversation(
|
||||
conversation_storage: &SharedConversationStorage,
|
||||
conv_id: &str,
|
||||
) -> Response {
|
||||
let conversation_id = ConversationId::from(conv_id);
|
||||
|
||||
match conversation_storage
|
||||
.get_conversation(&conversation_id)
|
||||
.await
|
||||
{
|
||||
Ok(Some(conversation)) => {
|
||||
(StatusCode::OK, Json(conversation_to_json(&conversation))).into_response()
|
||||
}
|
||||
Ok(None) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "Conversation not found"})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a conversation's metadata
|
||||
pub(super) async fn update_conversation(
|
||||
conversation_storage: &SharedConversationStorage,
|
||||
conv_id: &str,
|
||||
body: Value,
|
||||
) -> Response {
|
||||
let conversation_id = ConversationId::from(conv_id);
|
||||
|
||||
let current_meta = match conversation_storage
|
||||
.get_conversation(&conversation_id)
|
||||
.await
|
||||
{
|
||||
Ok(Some(meta)) => meta,
|
||||
Ok(None) => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "Conversation not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Patch {
|
||||
Set(String, Value),
|
||||
Delete(String),
|
||||
}
|
||||
|
||||
let mut patches: Vec<Patch> = Vec::new();
|
||||
|
||||
if let Some(metadata_val) = body.get("metadata") {
|
||||
if let Some(map) = metadata_val.as_object() {
|
||||
for (k, v) in map {
|
||||
if v.is_null() {
|
||||
patches.push(Patch::Delete(k.clone()));
|
||||
} else {
|
||||
patches.push(Patch::Set(k.clone(), v.clone()));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "metadata must be an object"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
let mut new_metadata = current_meta.metadata.clone().unwrap_or_default();
|
||||
for patch in patches {
|
||||
match patch {
|
||||
Patch::Set(k, v) => {
|
||||
new_metadata.insert(k, v);
|
||||
}
|
||||
Patch::Delete(k) => {
|
||||
new_metadata.remove(&k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if new_metadata.len() > MAX_METADATA_PROPERTIES {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"metadata cannot have more than {} properties",
|
||||
MAX_METADATA_PROPERTIES
|
||||
)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let final_metadata = if new_metadata.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(new_metadata)
|
||||
};
|
||||
|
||||
match conversation_storage
|
||||
.update_conversation(&conversation_id, final_metadata)
|
||||
.await
|
||||
{
|
||||
Ok(Some(conversation)) => {
|
||||
info!(conversation_id = %conversation_id.0, "Updated conversation");
|
||||
(StatusCode::OK, Json(conversation_to_json(&conversation))).into_response()
|
||||
}
|
||||
Ok(None) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "Conversation not found"})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to update conversation: {}", e)})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Delete a conversation
|
||||
pub(super) async fn delete_conversation(
|
||||
conversation_storage: &SharedConversationStorage,
|
||||
conv_id: &str,
|
||||
) -> Response {
|
||||
let conversation_id = ConversationId::from(conv_id);
|
||||
|
||||
match conversation_storage
|
||||
.get_conversation(&conversation_id)
|
||||
.await
|
||||
{
|
||||
Ok(Some(_)) => {}
|
||||
Ok(None) => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "Conversation not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
match conversation_storage
|
||||
.delete_conversation(&conversation_id)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
info!(conversation_id = %conversation_id.0, "Deleted conversation");
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"id": conversation_id.0,
|
||||
"object": "conversation.deleted",
|
||||
"deleted": true
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to delete conversation: {}", e)})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// List items in a conversation with pagination
|
||||
pub(super) async fn list_conversation_items(
|
||||
conversation_storage: &SharedConversationStorage,
|
||||
item_storage: &SharedConversationItemStorage,
|
||||
conv_id: &str,
|
||||
query_params: HashMap<String, String>,
|
||||
) -> Response {
|
||||
let conversation_id = ConversationId::from(conv_id);
|
||||
|
||||
match conversation_storage
|
||||
.get_conversation(&conversation_id)
|
||||
.await
|
||||
{
|
||||
Ok(Some(_)) => {}
|
||||
Ok(None) => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "Conversation not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
let limit: usize = query_params
|
||||
.get("limit")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(100);
|
||||
|
||||
let after = query_params.get("after").map(|s| s.to_string());
|
||||
|
||||
// Default to descending order (most recent first)
|
||||
let order = query_params
|
||||
.get("order")
|
||||
.and_then(|s| match s.as_str() {
|
||||
"asc" => Some(SortOrder::Asc),
|
||||
"desc" => Some(SortOrder::Desc),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or(SortOrder::Desc);
|
||||
|
||||
let params = ListParams {
|
||||
limit,
|
||||
order,
|
||||
after,
|
||||
};
|
||||
|
||||
match item_storage.list_items(&conversation_id, params).await {
|
||||
Ok(items) => {
|
||||
let item_values: Vec<Value> = items
|
||||
.iter()
|
||||
.map(|item| {
|
||||
let mut obj = serde_json::Map::new();
|
||||
obj.insert("id".to_string(), json!(item.id.0));
|
||||
obj.insert("type".to_string(), json!(item.item_type));
|
||||
obj.insert("created_at".to_string(), json!(item.created_at));
|
||||
|
||||
obj.insert("content".to_string(), item.content.clone());
|
||||
if let Some(status) = &item.status {
|
||||
obj.insert("status".to_string(), json!(status));
|
||||
}
|
||||
|
||||
Value::Object(obj)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let has_more = items.len() == limit;
|
||||
let last_id = items.last().map(|item| item.id.0.clone());
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"object": "list",
|
||||
"data": item_values,
|
||||
"has_more": has_more,
|
||||
"first_id": items.first().map(|item| &item.id.0),
|
||||
"last_id": last_id,
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to list items: {}", e)})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Persistence Operations
|
||||
// ============================================================================
|
||||
|
||||
/// Persist conversation items (delegates to persist_items_with_storages)
|
||||
pub(super) async fn persist_conversation_items(
|
||||
conversation_storage: Arc<dyn ConversationStorage>,
|
||||
item_storage: Arc<dyn ConversationItemStorage>,
|
||||
response_storage: Arc<dyn ResponseStorage>,
|
||||
response_json: &Value,
|
||||
original_body: &ResponsesRequest,
|
||||
) -> Result<(), String> {
|
||||
persist_items_with_storages(
|
||||
conversation_storage,
|
||||
item_storage,
|
||||
response_storage,
|
||||
response_json,
|
||||
original_body,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Helper function to create and link a conversation item (two-step API)
|
||||
async fn create_and_link_item(
|
||||
item_storage: &Arc<dyn ConversationItemStorage>,
|
||||
conv_id: &ConversationId,
|
||||
mut new_item: NewConversationItem,
|
||||
) -> Result<(), String> {
|
||||
// Set default status if not provided
|
||||
if new_item.status.is_none() {
|
||||
new_item.status = Some("completed".to_string());
|
||||
}
|
||||
|
||||
// Step 1: Create the item
|
||||
let created = item_storage
|
||||
.create_item(new_item)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to create item: {}", e))?;
|
||||
|
||||
// Step 2: Link it to the conversation
|
||||
item_storage
|
||||
.link_item(conv_id, &created.id, Utc::now())
|
||||
.await
|
||||
.map_err(|e| format!("Failed to link item: {}", e))?;
|
||||
|
||||
info!(
|
||||
conversation_id = %conv_id.0,
|
||||
item_id = %created.id.0,
|
||||
item_type = %created.item_type,
|
||||
"Persisted conversation item and link"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Persist conversation items with all storages
|
||||
async fn persist_items_with_storages(
|
||||
conversation_storage: Arc<dyn ConversationStorage>,
|
||||
item_storage: Arc<dyn ConversationItemStorage>,
|
||||
response_storage: Arc<dyn ResponseStorage>,
|
||||
response_json: &Value,
|
||||
original_body: &ResponsesRequest,
|
||||
) -> Result<(), String> {
|
||||
let conv_id = match &original_body.conversation {
|
||||
Some(id) => ConversationId::from(id.as_str()),
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
if conversation_storage
|
||||
.get_conversation(&conv_id)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to get conversation: {}", e))?
|
||||
.is_none()
|
||||
{
|
||||
warn!(conversation_id = %conv_id.0, "Conversation not found, skipping item persistence");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let response_id_str = response_json
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "Response missing id field".to_string())?;
|
||||
let response_id = ResponseId::from(response_id_str);
|
||||
|
||||
let response_id_opt = Some(response_id_str.to_string());
|
||||
|
||||
// Persist input items
|
||||
match &original_body.input {
|
||||
ResponseInput::Text(text) => {
|
||||
let new_item = NewConversationItem {
|
||||
id: None, // Let storage generate ID
|
||||
response_id: response_id_opt.clone(),
|
||||
item_type: "message".to_string(),
|
||||
role: Some("user".to_string()),
|
||||
content: json!([{ "type": "input_text", "text": text }]),
|
||||
status: Some("completed".to_string()),
|
||||
};
|
||||
create_and_link_item(&item_storage, &conv_id, new_item).await?;
|
||||
}
|
||||
ResponseInput::Items(items_array) => {
|
||||
for input_item in items_array {
|
||||
match input_item {
|
||||
crate::protocols::spec::ResponseInputOutputItem::Message {
|
||||
role,
|
||||
content,
|
||||
status,
|
||||
..
|
||||
} => {
|
||||
let content_v = serde_json::to_value(content)
|
||||
.map_err(|e| format!("Failed to serialize content: {}", e))?;
|
||||
let new_item = NewConversationItem {
|
||||
id: None,
|
||||
response_id: response_id_opt.clone(),
|
||||
item_type: "message".to_string(),
|
||||
role: Some(role.clone()),
|
||||
content: content_v,
|
||||
status: status.clone(),
|
||||
};
|
||||
create_and_link_item(&item_storage, &conv_id, new_item).await?;
|
||||
}
|
||||
_ => {
|
||||
// For other types (FunctionToolCall, etc.), serialize the whole item
|
||||
let item_val = serde_json::to_value(input_item)
|
||||
.map_err(|e| format!("Failed to serialize item: {}", e))?;
|
||||
let new_item = NewConversationItem {
|
||||
id: None,
|
||||
response_id: response_id_opt.clone(),
|
||||
item_type: "unknown".to_string(),
|
||||
role: None,
|
||||
content: item_val,
|
||||
status: Some("completed".to_string()),
|
||||
};
|
||||
create_and_link_item(&item_storage, &conv_id, new_item).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Persist output items
|
||||
if let Some(output_arr) = response_json.get("output").and_then(|v| v.as_array()) {
|
||||
for output_item in output_arr {
|
||||
if let Some(obj) = output_item.as_object() {
|
||||
let item_type = obj
|
||||
.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("message");
|
||||
|
||||
let role = obj.get("role").and_then(|v| v.as_str()).map(String::from);
|
||||
let status = obj.get("status").and_then(|v| v.as_str()).map(String::from);
|
||||
|
||||
let content = if item_type == "message" {
|
||||
obj.get("content").cloned().unwrap_or(json!([]))
|
||||
} else if item_type == "function_call" || item_type == "function_tool_call" {
|
||||
json!({
|
||||
"type": "function_call",
|
||||
"name": obj.get("name"),
|
||||
"call_id": obj.get("call_id").or_else(|| obj.get("id")),
|
||||
"arguments": obj.get("arguments")
|
||||
})
|
||||
} else if item_type == "function_call_output" {
|
||||
json!({
|
||||
"type": "function_call_output",
|
||||
"call_id": obj.get("call_id"),
|
||||
"output": obj.get("output")
|
||||
})
|
||||
} else {
|
||||
output_item.clone()
|
||||
};
|
||||
|
||||
let new_item = NewConversationItem {
|
||||
id: None,
|
||||
response_id: response_id_opt.clone(),
|
||||
item_type: item_type.to_string(),
|
||||
role,
|
||||
content,
|
||||
status,
|
||||
};
|
||||
create_and_link_item(&item_storage, &conv_id, new_item).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store the full response using the shared helper
|
||||
let mut stored_response = build_stored_response(response_json, original_body);
|
||||
stored_response.id = response_id;
|
||||
let final_response_id = stored_response.id.clone();
|
||||
|
||||
response_storage
|
||||
.store_response(stored_response)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to store response in conversation: {}", e))?;
|
||||
|
||||
info!(conversation_id = %conv_id.0, response_id = %final_response_id.0, "Persisted conversation items and response");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Convert conversation to JSON response
|
||||
fn conversation_to_json(conversation: &Conversation) -> Value {
|
||||
let mut response = json!({
|
||||
"id": conversation.id.0,
|
||||
"object": "conversation",
|
||||
"created_at": conversation.created_at.timestamp()
|
||||
});
|
||||
|
||||
if let Some(metadata) = &conversation.metadata {
|
||||
if !metadata.is_empty() {
|
||||
if let Some(obj) = response.as_object_mut() {
|
||||
obj.insert("metadata".to_string(), Value::Object(metadata.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
967
sgl-router/src/routers/openai/mcp.rs
Normal file
967
sgl-router/src/routers/openai/mcp.rs
Normal file
@@ -0,0 +1,967 @@
|
||||
//! MCP (Model Context Protocol) Integration Module
|
||||
//!
|
||||
//! This module contains all MCP-related functionality for the OpenAI router:
|
||||
//! - Tool loop state management for multi-turn tool calling
|
||||
//! - MCP tool execution and result handling
|
||||
//! - Output item builders for MCP-specific response formats
|
||||
//! - SSE event generation for streaming MCP operations
|
||||
//! - Payload transformation for MCP tool interception
|
||||
//! - Metadata injection for MCP operations
|
||||
|
||||
use crate::mcp::McpClientManager;
|
||||
use crate::protocols::spec::{ResponseInput, ResponseToolType, ResponsesRequest};
|
||||
use crate::routers::header_utils::apply_request_headers;
|
||||
use axum::http::HeaderMap;
|
||||
use bytes::Bytes;
|
||||
use serde_json::{json, to_value, Value};
|
||||
use std::{io, sync::Arc};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::utils::event_types;
|
||||
|
||||
// ============================================================================
|
||||
// Configuration and State Types
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for MCP tool calling loops
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct McpLoopConfig {
|
||||
/// Maximum iterations as safety limit (internal only, default: 10)
|
||||
/// Prevents infinite loops when max_tool_calls is not set
|
||||
pub max_iterations: usize,
|
||||
}
|
||||
|
||||
impl Default for McpLoopConfig {
|
||||
fn default() -> Self {
|
||||
Self { max_iterations: 10 }
|
||||
}
|
||||
}
|
||||
|
||||
/// State for tracking multi-turn tool calling loop
|
||||
pub(crate) struct ToolLoopState {
|
||||
/// Current iteration number (starts at 0, increments with each tool call)
|
||||
pub iteration: usize,
|
||||
/// Total number of tool calls executed
|
||||
pub total_calls: usize,
|
||||
/// Conversation history (function_call and function_call_output items)
|
||||
pub conversation_history: Vec<Value>,
|
||||
/// Original user input (preserved for building resume payloads)
|
||||
pub original_input: ResponseInput,
|
||||
}
|
||||
|
||||
impl ToolLoopState {
|
||||
pub fn new(original_input: ResponseInput) -> Self {
|
||||
Self {
|
||||
iteration: 0,
|
||||
total_calls: 0,
|
||||
conversation_history: Vec::new(),
|
||||
original_input,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a tool call in the loop state
|
||||
pub fn record_call(
|
||||
&mut self,
|
||||
call_id: String,
|
||||
tool_name: String,
|
||||
args_json_str: String,
|
||||
output_str: String,
|
||||
) {
|
||||
// Add function_call item to history
|
||||
let func_item = json!({
|
||||
"type": event_types::ITEM_TYPE_FUNCTION_CALL,
|
||||
"call_id": call_id,
|
||||
"name": tool_name,
|
||||
"arguments": args_json_str
|
||||
});
|
||||
self.conversation_history.push(func_item);
|
||||
|
||||
// Add function_call_output item to history
|
||||
let output_item = json!({
|
||||
"type": "function_call_output",
|
||||
"call_id": call_id,
|
||||
"output": output_str
|
||||
});
|
||||
self.conversation_history.push(output_item);
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a function call being accumulated across delta events
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct FunctionCallInProgress {
|
||||
pub call_id: String,
|
||||
pub name: String,
|
||||
pub arguments_buffer: String,
|
||||
pub output_index: usize,
|
||||
pub last_obfuscation: Option<String>,
|
||||
pub assigned_output_index: Option<usize>,
|
||||
}
|
||||
|
||||
impl FunctionCallInProgress {
|
||||
pub fn new(call_id: String, output_index: usize) -> Self {
|
||||
Self {
|
||||
call_id,
|
||||
name: String::new(),
|
||||
arguments_buffer: String::new(),
|
||||
output_index,
|
||||
last_obfuscation: None,
|
||||
assigned_output_index: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_complete(&self) -> bool {
|
||||
// A tool call is complete if it has a name
|
||||
!self.name.is_empty()
|
||||
}
|
||||
|
||||
pub fn effective_output_index(&self) -> usize {
|
||||
self.assigned_output_index.unwrap_or(self.output_index)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MCP Manager Integration
|
||||
// ============================================================================
|
||||
|
||||
/// Build a request-scoped MCP manager from request tools, if present.
|
||||
pub(super) async fn mcp_manager_from_request_tools(
|
||||
tools: &[crate::protocols::spec::ResponseTool],
|
||||
) -> Option<Arc<McpClientManager>> {
|
||||
let tool = tools
|
||||
.iter()
|
||||
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?;
|
||||
let server_url = tool.server_url.as_ref()?.trim().to_string();
|
||||
if !(server_url.starts_with("http://") || server_url.starts_with("https://")) {
|
||||
warn!(
|
||||
"Ignoring MCP server_url with unsupported scheme: {}",
|
||||
server_url
|
||||
);
|
||||
return None;
|
||||
}
|
||||
let name = tool
|
||||
.server_label
|
||||
.clone()
|
||||
.unwrap_or_else(|| "request-mcp".to_string());
|
||||
let token = tool.authorization.clone();
|
||||
let transport = if server_url.contains("/sse") {
|
||||
crate::mcp::McpTransport::Sse {
|
||||
url: server_url,
|
||||
token,
|
||||
}
|
||||
} else {
|
||||
crate::mcp::McpTransport::Streamable {
|
||||
url: server_url,
|
||||
token,
|
||||
}
|
||||
};
|
||||
let cfg = crate::mcp::McpConfig {
|
||||
servers: vec![crate::mcp::McpServerConfig { name, transport }],
|
||||
};
|
||||
match McpClientManager::new(cfg).await {
|
||||
Ok(mgr) => Some(Arc::new(mgr)),
|
||||
Err(err) => {
|
||||
warn!("Failed to initialize request-scoped MCP manager: {}", err);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tool Execution
|
||||
// ============================================================================
|
||||
|
||||
/// Execute an MCP tool call
|
||||
pub(super) async fn execute_mcp_call(
|
||||
mcp_mgr: &Arc<McpClientManager>,
|
||||
tool_name: &str,
|
||||
args_json_str: &str,
|
||||
) -> Result<(String, String), String> {
|
||||
let args_value: Value =
|
||||
serde_json::from_str(args_json_str).map_err(|e| format!("parse tool args: {}", e))?;
|
||||
let args_obj = args_value.as_object().cloned();
|
||||
|
||||
let server_name = mcp_mgr
|
||||
.get_tool(tool_name)
|
||||
.map(|t| t.server)
|
||||
.ok_or_else(|| format!("tool not found: {}", tool_name))?;
|
||||
|
||||
let result = mcp_mgr
|
||||
.call_tool(tool_name, args_obj)
|
||||
.await
|
||||
.map_err(|e| format!("tool call failed: {}", e))?;
|
||||
|
||||
let output_str = serde_json::to_string(&result)
|
||||
.map_err(|e| format!("Failed to serialize tool result: {}", e))?;
|
||||
Ok((server_name, output_str))
|
||||
}
|
||||
|
||||
/// Execute detected tool calls and send completion events to client
|
||||
/// Returns false if client disconnected during execution
|
||||
pub(super) async fn execute_streaming_tool_calls(
|
||||
pending_calls: Vec<FunctionCallInProgress>,
|
||||
active_mcp: &Arc<McpClientManager>,
|
||||
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
|
||||
state: &mut ToolLoopState,
|
||||
server_label: &str,
|
||||
sequence_number: &mut u64,
|
||||
) -> bool {
|
||||
// Execute all pending tool calls (sequential, as PR3 is skipped)
|
||||
for call in pending_calls {
|
||||
// Skip if name is empty (invalid call)
|
||||
if call.name.is_empty() {
|
||||
warn!(
|
||||
"Skipping incomplete tool call: name is empty, args_len={}",
|
||||
call.arguments_buffer.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
info!(
|
||||
"Executing tool call during streaming: {} ({})",
|
||||
call.name, call.call_id
|
||||
);
|
||||
|
||||
// Use empty JSON object if arguments_buffer is empty
|
||||
let args_str = if call.arguments_buffer.is_empty() {
|
||||
"{}"
|
||||
} else {
|
||||
&call.arguments_buffer
|
||||
};
|
||||
|
||||
let call_result = execute_mcp_call(active_mcp, &call.name, args_str).await;
|
||||
let (output_str, success, error_msg) = match call_result {
|
||||
Ok((_, output)) => (output, true, None),
|
||||
Err(err) => {
|
||||
warn!("Tool execution failed during streaming: {}", err);
|
||||
(json!({ "error": &err }).to_string(), false, Some(err))
|
||||
}
|
||||
};
|
||||
|
||||
// Send mcp_call completion event to client
|
||||
if !send_mcp_call_completion_events_with_error(
|
||||
tx,
|
||||
&call,
|
||||
&output_str,
|
||||
server_label,
|
||||
success,
|
||||
error_msg.as_deref(),
|
||||
sequence_number,
|
||||
) {
|
||||
// Client disconnected, no point continuing tool execution
|
||||
return false;
|
||||
}
|
||||
|
||||
// Record the call
|
||||
state.record_call(call.call_id, call.name, call.arguments_buffer, output_str);
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Payload Transformation
|
||||
// ============================================================================
|
||||
|
||||
/// Transform payload to replace MCP tools with function tools for streaming
|
||||
pub(super) fn prepare_mcp_payload_for_streaming(
|
||||
payload: &mut Value,
|
||||
active_mcp: &Arc<McpClientManager>,
|
||||
) {
|
||||
if let Some(obj) = payload.as_object_mut() {
|
||||
// Remove any non-function tools from outgoing payload
|
||||
if let Some(v) = obj.get_mut("tools") {
|
||||
if let Some(arr) = v.as_array_mut() {
|
||||
arr.retain(|item| {
|
||||
item.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s == event_types::ITEM_TYPE_FUNCTION)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Build function tools for all discovered MCP tools
|
||||
let mut tools_json = Vec::new();
|
||||
let tools = active_mcp.list_tools();
|
||||
for t in tools {
|
||||
let parameters = t.parameters.clone().unwrap_or(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"additionalProperties": false
|
||||
}));
|
||||
let tool = serde_json::json!({
|
||||
"type": event_types::ITEM_TYPE_FUNCTION,
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"parameters": parameters
|
||||
});
|
||||
tools_json.push(tool);
|
||||
}
|
||||
if !tools_json.is_empty() {
|
||||
obj.insert("tools".to_string(), Value::Array(tools_json));
|
||||
obj.insert("tool_choice".to_string(), Value::String("auto".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a resume payload with conversation history
|
||||
pub(super) fn build_resume_payload(
|
||||
base_payload: &Value,
|
||||
conversation_history: &[Value],
|
||||
original_input: &ResponseInput,
|
||||
tools_json: &Value,
|
||||
is_streaming: bool,
|
||||
) -> Result<Value, String> {
|
||||
// Clone the base payload which already has cleaned fields
|
||||
let mut payload = base_payload.clone();
|
||||
|
||||
let obj = payload
|
||||
.as_object_mut()
|
||||
.ok_or_else(|| "payload not an object".to_string())?;
|
||||
|
||||
// Build input array: start with original user input
|
||||
let mut input_array = Vec::new();
|
||||
|
||||
// Add original user message
|
||||
// For structured input, serialize the original input items
|
||||
match original_input {
|
||||
ResponseInput::Text(text) => {
|
||||
let user_item = json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{ "type": "input_text", "text": text }]
|
||||
});
|
||||
input_array.push(user_item);
|
||||
}
|
||||
ResponseInput::Items(items) => {
|
||||
// Items are already structured ResponseInputOutputItem, convert to JSON
|
||||
if let Ok(items_value) = to_value(items) {
|
||||
if let Some(items_arr) = items_value.as_array() {
|
||||
input_array.extend_from_slice(items_arr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add all conversation history (function calls and outputs)
|
||||
input_array.extend_from_slice(conversation_history);
|
||||
|
||||
obj.insert("input".to_string(), Value::Array(input_array));
|
||||
|
||||
// Use the transformed tools (function tools, not MCP tools)
|
||||
if let Some(tools_arr) = tools_json.as_array() {
|
||||
if !tools_arr.is_empty() {
|
||||
obj.insert("tools".to_string(), tools_json.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Set streaming mode based on caller's context
|
||||
obj.insert("stream".to_string(), Value::Bool(is_streaming));
|
||||
obj.insert("store".to_string(), Value::Bool(false));
|
||||
|
||||
// Note: SGLang-specific fields were already removed from base_payload
|
||||
// before it was passed to execute_tool_loop (see route_responses lines 1935-1946)
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSE Event Senders
|
||||
// ============================================================================
|
||||
|
||||
/// Send mcp_list_tools events to client at the start of streaming
|
||||
/// Returns false if client disconnected
|
||||
pub(super) fn send_mcp_list_tools_events(
|
||||
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
|
||||
mcp: &Arc<McpClientManager>,
|
||||
server_label: &str,
|
||||
output_index: usize,
|
||||
sequence_number: &mut u64,
|
||||
) -> bool {
|
||||
let tools_item_full = build_mcp_list_tools_item(mcp, server_label);
|
||||
let item_id = tools_item_full
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
// Create empty tools version for the initial added event
|
||||
let mut tools_item_empty = tools_item_full.clone();
|
||||
if let Some(obj) = tools_item_empty.as_object_mut() {
|
||||
obj.insert("tools".to_string(), json!([]));
|
||||
}
|
||||
|
||||
// Event 1: response.output_item.added with empty tools
|
||||
let event1_payload = json!({
|
||||
"type": event_types::OUTPUT_ITEM_ADDED,
|
||||
"sequence_number": *sequence_number,
|
||||
"output_index": output_index,
|
||||
"item": tools_item_empty
|
||||
});
|
||||
*sequence_number += 1;
|
||||
let event1 = format!(
|
||||
"event: {}\ndata: {}\n\n",
|
||||
event_types::OUTPUT_ITEM_ADDED,
|
||||
event1_payload
|
||||
);
|
||||
if tx.send(Ok(Bytes::from(event1))).is_err() {
|
||||
return false; // Client disconnected
|
||||
}
|
||||
|
||||
// Event 2: response.mcp_list_tools.in_progress
|
||||
let event2_payload = json!({
|
||||
"type": event_types::MCP_LIST_TOOLS_IN_PROGRESS,
|
||||
"sequence_number": *sequence_number,
|
||||
"output_index": output_index,
|
||||
"item_id": item_id
|
||||
});
|
||||
*sequence_number += 1;
|
||||
let event2 = format!(
|
||||
"event: {}\ndata: {}\n\n",
|
||||
event_types::MCP_LIST_TOOLS_IN_PROGRESS,
|
||||
event2_payload
|
||||
);
|
||||
if tx.send(Ok(Bytes::from(event2))).is_err() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Event 3: response.mcp_list_tools.completed
|
||||
let event3_payload = json!({
|
||||
"type": event_types::MCP_LIST_TOOLS_COMPLETED,
|
||||
"sequence_number": *sequence_number,
|
||||
"output_index": output_index,
|
||||
"item_id": item_id
|
||||
});
|
||||
*sequence_number += 1;
|
||||
let event3 = format!(
|
||||
"event: {}\ndata: {}\n\n",
|
||||
event_types::MCP_LIST_TOOLS_COMPLETED,
|
||||
event3_payload
|
||||
);
|
||||
if tx.send(Ok(Bytes::from(event3))).is_err() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Event 4: response.output_item.done with full tools list
|
||||
let event4_payload = json!({
|
||||
"type": event_types::OUTPUT_ITEM_DONE,
|
||||
"sequence_number": *sequence_number,
|
||||
"output_index": output_index,
|
||||
"item": tools_item_full
|
||||
});
|
||||
*sequence_number += 1;
|
||||
let event4 = format!(
|
||||
"event: {}\ndata: {}\n\n",
|
||||
event_types::OUTPUT_ITEM_DONE,
|
||||
event4_payload
|
||||
);
|
||||
tx.send(Ok(Bytes::from(event4))).is_ok()
|
||||
}
|
||||
|
||||
/// Send mcp_call completion events after tool execution
|
||||
/// Returns false if client disconnected
|
||||
pub(super) fn send_mcp_call_completion_events_with_error(
|
||||
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
|
||||
call: &FunctionCallInProgress,
|
||||
output: &str,
|
||||
server_label: &str,
|
||||
success: bool,
|
||||
error_msg: Option<&str>,
|
||||
sequence_number: &mut u64,
|
||||
) -> bool {
|
||||
let effective_output_index = call.effective_output_index();
|
||||
|
||||
// Build mcp_call item (reuse existing function)
|
||||
let mcp_call_item = build_mcp_call_item(
|
||||
&call.name,
|
||||
&call.arguments_buffer,
|
||||
output,
|
||||
server_label,
|
||||
success,
|
||||
error_msg,
|
||||
);
|
||||
|
||||
// Get the mcp_call item_id
|
||||
let item_id = mcp_call_item
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
// Event 1: response.mcp_call.completed
|
||||
let completed_payload = json!({
|
||||
"type": event_types::MCP_CALL_COMPLETED,
|
||||
"sequence_number": *sequence_number,
|
||||
"output_index": effective_output_index,
|
||||
"item_id": item_id
|
||||
});
|
||||
*sequence_number += 1;
|
||||
|
||||
let completed_event = format!(
|
||||
"event: {}\ndata: {}\n\n",
|
||||
event_types::MCP_CALL_COMPLETED,
|
||||
completed_payload
|
||||
);
|
||||
if tx.send(Ok(Bytes::from(completed_event))).is_err() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Event 2: response.output_item.done (with completed mcp_call)
|
||||
let done_payload = json!({
|
||||
"type": event_types::OUTPUT_ITEM_DONE,
|
||||
"sequence_number": *sequence_number,
|
||||
"output_index": effective_output_index,
|
||||
"item": mcp_call_item
|
||||
});
|
||||
*sequence_number += 1;
|
||||
|
||||
let done_event = format!(
|
||||
"event: {}\ndata: {}\n\n",
|
||||
event_types::OUTPUT_ITEM_DONE,
|
||||
done_payload
|
||||
);
|
||||
tx.send(Ok(Bytes::from(done_event))).is_ok()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Metadata Injection
|
||||
// ============================================================================
|
||||
|
||||
/// Inject MCP metadata into a streaming response
|
||||
pub(super) fn inject_mcp_metadata_streaming(
|
||||
response: &mut Value,
|
||||
state: &ToolLoopState,
|
||||
mcp: &Arc<McpClientManager>,
|
||||
server_label: &str,
|
||||
) {
|
||||
if let Some(output_array) = response.get_mut("output").and_then(|v| v.as_array_mut()) {
|
||||
output_array.retain(|item| {
|
||||
item.get("type").and_then(|t| t.as_str()) != Some(event_types::ITEM_TYPE_MCP_LIST_TOOLS)
|
||||
});
|
||||
|
||||
let list_tools_item = build_mcp_list_tools_item(mcp, server_label);
|
||||
output_array.insert(0, list_tools_item);
|
||||
|
||||
let mcp_call_items =
|
||||
build_executed_mcp_call_items(&state.conversation_history, server_label);
|
||||
let mut insert_pos = 1;
|
||||
for item in mcp_call_items {
|
||||
output_array.insert(insert_pos, item);
|
||||
insert_pos += 1;
|
||||
}
|
||||
} else if let Some(obj) = response.as_object_mut() {
|
||||
let mut output_items = Vec::new();
|
||||
output_items.push(build_mcp_list_tools_item(mcp, server_label));
|
||||
output_items.extend(build_executed_mcp_call_items(
|
||||
&state.conversation_history,
|
||||
server_label,
|
||||
));
|
||||
obj.insert("output".to_string(), Value::Array(output_items));
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tool Loop Execution
|
||||
// ============================================================================
|
||||
|
||||
/// Execute the tool calling loop
///
/// Repeatedly posts `current_payload` to the upstream `url`; whenever the
/// model's response contains a function call, executes it through the MCP
/// manager, folds the call and its output into a resume payload, and
/// re-submits. Stops when the model returns no function call or when the
/// effective call limit (user `max_tool_calls` capped by the safety
/// `config.max_iterations`) is exceeded.
///
/// Returns the final upstream response JSON — with `mcp_list_tools` and
/// `mcp_call` items injected when any tools ran — or an error string on
/// transport/parse failure.
pub(super) async fn execute_tool_loop(
    client: &reqwest::Client,
    url: &str,
    headers: Option<&HeaderMap>,
    initial_payload: Value,
    original_body: &ResponsesRequest,
    active_mcp: &Arc<McpClientManager>,
    config: &McpLoopConfig,
) -> Result<Value, String> {
    let mut state = ToolLoopState::new(original_body.input.clone());

    // Get max_tool_calls from request (None means no user-specified limit)
    let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize);

    // Keep initial_payload as base template (already has fields cleaned)
    let base_payload = initial_payload.clone();
    let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([]));
    let mut current_payload = initial_payload;

    info!(
        "Starting tool loop: max_tool_calls={:?}, max_iterations={}",
        max_tool_calls, config.max_iterations
    );

    loop {
        // Make request to upstream, forwarding caller headers when present.
        let request_builder = client.post(url).json(&current_payload);
        let request_builder = if let Some(headers) = headers {
            apply_request_headers(headers, request_builder, true)
        } else {
            request_builder
        };

        let response = request_builder
            .send()
            .await
            .map_err(|e| format!("upstream request failed: {}", e))?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(format!("upstream error {}: {}", status, body));
        }

        let mut response_json = response
            .json::<Value>()
            .await
            .map_err(|e| format!("parse response: {}", e))?;

        // Check for function call
        if let Some((call_id, tool_name, args_json_str)) = extract_function_call(&response_json) {
            state.iteration += 1;
            state.total_calls += 1;

            info!(
                "Tool loop iteration {}: calling {} (call_id: {})",
                state.iteration, tool_name, call_id
            );

            // Check combined limit: use minimum of user's max_tool_calls (if set) and safety max_iterations
            let effective_limit = match max_tool_calls {
                Some(user_max) => user_max.min(config.max_iterations),
                None => config.max_iterations,
            };

            if state.total_calls > effective_limit {
                // Log which of the two limits actually tripped.
                if let Some(user_max) = max_tool_calls {
                    if state.total_calls > user_max {
                        warn!("Reached user-specified max_tool_calls limit: {}", user_max);
                    } else {
                        warn!(
                            "Reached safety max_iterations limit: {}",
                            config.max_iterations
                        );
                    }
                } else {
                    warn!(
                        "Reached safety max_iterations limit: {}",
                        config.max_iterations
                    );
                }

                // Surface the partial result rather than failing outright.
                return build_incomplete_response(
                    response_json,
                    state,
                    "max_tool_calls",
                    active_mcp,
                    original_body,
                );
            }

            // Execute tool
            let call_result = execute_mcp_call(active_mcp, &tool_name, &args_json_str).await;

            let output_str = match call_result {
                Ok((_, output)) => output,
                Err(err) => {
                    warn!("Tool execution failed: {}", err);
                    // Return error as output, let model decide how to proceed
                    json!({ "error": err }).to_string()
                }
            };

            // Record the call
            state.record_call(call_id, tool_name, args_json_str, output_str);

            // Build resume payload
            current_payload = build_resume_payload(
                &base_payload,
                &state.conversation_history,
                &state.original_input,
                &tools_json,
                false, // is_streaming = false (non-streaming tool loop)
            )?;
        } else {
            // No more tool calls, we're done
            info!(
                "Tool loop completed: {} iterations, {} total calls",
                state.iteration, state.total_calls
            );

            // Inject MCP output items if we executed any tools
            if state.total_calls > 0 {
                // server_label from the client's MCP tool declaration, with a
                // generic fallback.
                let server_label = original_body
                    .tools
                    .iter()
                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
                    .and_then(|t| t.server_label.as_deref())
                    .unwrap_or("mcp");

                // Build mcp_list_tools item
                let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);

                // Insert at beginning of output array
                if let Some(output_array) = response_json
                    .get_mut("output")
                    .and_then(|v| v.as_array_mut())
                {
                    output_array.insert(0, list_tools_item);

                    // Build mcp_call items using helper function
                    let mcp_call_items =
                        build_executed_mcp_call_items(&state.conversation_history, server_label);

                    // Insert mcp_call items after mcp_list_tools using mutable position
                    let mut insert_pos = 1;
                    for item in mcp_call_items {
                        output_array.insert(insert_pos, item);
                        insert_pos += 1;
                    }
                }
            }

            return Ok(response_json);
        }
    }
}
|
||||
|
||||
/// Build an incomplete response when limits are exceeded
///
/// Marks the response `completed` (partial success, not a failure) with
/// `incomplete_details.reason`, converts pending (unexecuted) function
/// calls in the output to failed `mcp_call` items, prepends an
/// `mcp_list_tools` item plus `mcp_call` items for the calls that did run,
/// and records a truncation warning under `metadata.mcp` when that object
/// already exists.
///
/// Returns `Err` when `response` is not a JSON object.
pub(super) fn build_incomplete_response(
    mut response: Value,
    state: ToolLoopState,
    reason: &str,
    active_mcp: &Arc<McpClientManager>,
    original_body: &ResponsesRequest,
) -> Result<Value, String> {
    let obj = response
        .as_object_mut()
        .ok_or_else(|| "response not an object".to_string())?;

    // Set status to completed (not failed - partial success)
    obj.insert("status".to_string(), Value::String("completed".to_string()));

    // Set incomplete_details
    obj.insert(
        "incomplete_details".to_string(),
        json!({ "reason": reason }),
    );

    // Convert any function_call in output to mcp_call format
    if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
        // server_label from the client's MCP tool declaration, with fallback.
        let server_label = original_body
            .tools
            .iter()
            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
            .and_then(|t| t.server_label.as_deref())
            .unwrap_or("mcp");

        // Find any function_call items and convert them to mcp_call (incomplete)
        let mut mcp_call_items = Vec::new();
        for item in output_array.iter() {
            let item_type = item.get("type").and_then(|t| t.as_str());
            if item_type == Some(event_types::ITEM_TYPE_FUNCTION_TOOL_CALL)
                || item_type == Some(event_types::ITEM_TYPE_FUNCTION_CALL)
            {
                let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
                let args = item
                    .get("arguments")
                    .and_then(|v| v.as_str())
                    .unwrap_or("{}");

                // Mark as incomplete - not executed
                let mcp_call_item = build_mcp_call_item(
                    tool_name,
                    args,
                    "", // No output - wasn't executed
                    server_label,
                    false, // Not successful
                    Some("Not executed - response stopped due to limit"),
                );
                mcp_call_items.push(mcp_call_item);
            }
        }

        // Add mcp_list_tools and executed mcp_call items at the beginning
        if state.total_calls > 0 || !mcp_call_items.is_empty() {
            let list_tools_item = build_mcp_list_tools_item(active_mcp, server_label);
            output_array.insert(0, list_tools_item);

            // Add mcp_call items for executed calls using helper
            let executed_items =
                build_executed_mcp_call_items(&state.conversation_history, server_label);

            // insert_pos walks forward so executed calls precede incomplete ones.
            let mut insert_pos = 1;
            for item in executed_items {
                output_array.insert(insert_pos, item);
                insert_pos += 1;
            }

            // Add incomplete mcp_call items
            for item in mcp_call_items {
                output_array.insert(insert_pos, item);
                insert_pos += 1;
            }
        }
    }

    // Add warning to metadata
    // NOTE(review): this only patches an existing metadata.mcp object; when
    // the upstream response has no metadata.mcp, the warning is dropped
    // silently — confirm that is intended.
    if let Some(metadata_val) = obj.get_mut("metadata") {
        if let Some(metadata_obj) = metadata_val.as_object_mut() {
            if let Some(mcp_val) = metadata_obj.get_mut("mcp") {
                if let Some(mcp_obj) = mcp_val.as_object_mut() {
                    mcp_obj.insert(
                        "truncation_warning".to_string(),
                        Value::String(format!(
                            "Loop terminated at {} iterations, {} total calls (reason: {})",
                            state.iteration, state.total_calls, reason
                        )),
                    );
                }
            }
        }
    }

    Ok(response)
}
|
||||
|
||||
// ============================================================================
|
||||
// Output Item Builders
|
||||
// ============================================================================
|
||||
|
||||
/// Generate a unique ID for MCP output items (similar to OpenAI format)
|
||||
pub(super) fn generate_mcp_id(prefix: &str) -> String {
|
||||
use rand::RngCore;
|
||||
let mut rng = rand::rng();
|
||||
let mut bytes = [0u8; 30];
|
||||
rng.fill_bytes(&mut bytes);
|
||||
let hex_string: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
|
||||
format!("{}_{}", prefix, hex_string)
|
||||
}
|
||||
|
||||
/// Build an mcp_list_tools output item
|
||||
pub(super) fn build_mcp_list_tools_item(mcp: &Arc<McpClientManager>, server_label: &str) -> Value {
|
||||
let tools = mcp.list_tools();
|
||||
let tools_json: Vec<Value> = tools
|
||||
.iter()
|
||||
.map(|t| {
|
||||
json!({
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"input_schema": t.parameters.clone().unwrap_or_else(|| json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"additionalProperties": false
|
||||
})),
|
||||
"annotations": {
|
||||
"read_only": false
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
json!({
|
||||
"id": generate_mcp_id("mcpl"),
|
||||
"type": event_types::ITEM_TYPE_MCP_LIST_TOOLS,
|
||||
"server_label": server_label,
|
||||
"tools": tools_json
|
||||
})
|
||||
}
|
||||
|
||||
/// Build an mcp_call output item
|
||||
pub(super) fn build_mcp_call_item(
|
||||
tool_name: &str,
|
||||
arguments: &str,
|
||||
output: &str,
|
||||
server_label: &str,
|
||||
success: bool,
|
||||
error: Option<&str>,
|
||||
) -> Value {
|
||||
json!({
|
||||
"id": generate_mcp_id("mcp"),
|
||||
"type": event_types::ITEM_TYPE_MCP_CALL,
|
||||
"status": if success { "completed" } else { "failed" },
|
||||
"approval_request_id": Value::Null,
|
||||
"arguments": arguments,
|
||||
"error": error,
|
||||
"name": tool_name,
|
||||
"output": output,
|
||||
"server_label": server_label
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper function to build mcp_call items from executed tool calls in conversation history
|
||||
pub(super) fn build_executed_mcp_call_items(
|
||||
conversation_history: &[Value],
|
||||
server_label: &str,
|
||||
) -> Vec<Value> {
|
||||
let mut mcp_call_items = Vec::new();
|
||||
|
||||
for item in conversation_history {
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some(event_types::ITEM_TYPE_FUNCTION_CALL) {
|
||||
let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let args = item
|
||||
.get("arguments")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("{}");
|
||||
|
||||
// Find corresponding output
|
||||
let output_item = conversation_history.iter().find(|o| {
|
||||
o.get("type").and_then(|t| t.as_str()) == Some("function_call_output")
|
||||
&& o.get("call_id").and_then(|c| c.as_str()) == Some(call_id)
|
||||
});
|
||||
|
||||
let output_str = output_item
|
||||
.and_then(|o| o.get("output").and_then(|v| v.as_str()))
|
||||
.unwrap_or("{}");
|
||||
|
||||
// Check if output contains error by parsing JSON
|
||||
let is_error = serde_json::from_str::<Value>(output_str)
|
||||
.map(|v| v.get("error").is_some())
|
||||
.unwrap_or(false);
|
||||
|
||||
let mcp_call_item = build_mcp_call_item(
|
||||
tool_name,
|
||||
args,
|
||||
output_str,
|
||||
server_label,
|
||||
!is_error,
|
||||
if is_error {
|
||||
Some("Tool execution failed")
|
||||
} else {
|
||||
None
|
||||
},
|
||||
);
|
||||
mcp_call_items.push(mcp_call_item);
|
||||
}
|
||||
}
|
||||
|
||||
mcp_call_items
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Extract function call from a response
|
||||
pub(super) fn extract_function_call(resp: &Value) -> Option<(String, String, String)> {
|
||||
let output = resp.get("output")?.as_array()?;
|
||||
for item in output {
|
||||
let obj = item.as_object()?;
|
||||
let t = obj.get("type")?.as_str()?;
|
||||
if t == event_types::ITEM_TYPE_FUNCTION_TOOL_CALL
|
||||
|| t == event_types::ITEM_TYPE_FUNCTION_CALL
|
||||
{
|
||||
let call_id = obj
|
||||
.get("call_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| {
|
||||
obj.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
})?;
|
||||
let name = obj.get("name")?.as_str()?.to_string();
|
||||
let arguments = obj.get("arguments")?.as_str()?.to_string();
|
||||
return Some((call_id, name, arguments));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
18
sgl-router/src/routers/openai/mod.rs
Normal file
18
sgl-router/src/routers/openai/mod.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
//! OpenAI-compatible router implementation
//!
//! This module provides OpenAI-compatible API routing with support for:
//! - Streaming and non-streaming responses
//! - MCP (Model Context Protocol) tool calling
//! - Response storage and conversation management
//! - Multi-turn tool execution loops
//! - SSE (Server-Sent Events) streaming

mod conversations; // Conversation CRUD operations and persistence
mod mcp; // MCP tool-loop execution and mcp_* output-item builders
mod responses; // Response storage, patching, and extraction utilities
mod router; // Main coordinator that delegates to the other modules
mod streaming; // SSE streaming support
mod utils; // Shared helpers (e.g. the `event_types` constants)

// Re-export the main router type for external use
pub use router::OpenAIRouter;
|
||||
368
sgl-router/src/routers/openai/responses.rs
Normal file
368
sgl-router/src/routers/openai/responses.rs
Normal file
@@ -0,0 +1,368 @@
|
||||
//! Response storage, patching, and extraction utilities
|
||||
|
||||
use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse};
|
||||
use crate::protocols::spec::{ResponseInput, ResponseToolType, ResponsesRequest};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::utils::event_types;
|
||||
|
||||
// ============================================================================
|
||||
// Response Storage Operations
|
||||
// ============================================================================
|
||||
|
||||
/// Store a response internally (checks if storage is enabled)
|
||||
pub(super) async fn store_response_internal(
|
||||
response_storage: &SharedResponseStorage,
|
||||
response_json: &Value,
|
||||
original_body: &ResponsesRequest,
|
||||
) -> Result<(), String> {
|
||||
if !original_body.store {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match store_response_impl(response_storage, response_json, original_body).await {
|
||||
Ok(response_id) => {
|
||||
info!(response_id = %response_id.0, "Stored response locally");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a StoredResponse from response JSON and original request
///
/// Field precedence: string values present in the upstream `response_json`
/// win; the original request supplies fallbacks for instructions, model,
/// user, metadata, and previous_response_id. The full upstream JSON is
/// kept in `raw_response`.
pub(super) fn build_stored_response(
    response_json: &Value,
    original_body: &ResponsesRequest,
) -> StoredResponse {
    // Only plain-text input is captured verbatim; structured item lists are
    // summarized with a placeholder.
    let input_text = match &original_body.input {
        ResponseInput::Text(text) => text.clone(),
        ResponseInput::Items(_) => "complex input".to_string(),
    };

    // First output_text part found in the response's output array, if any.
    let output_text = extract_primary_output_text(response_json).unwrap_or_default();

    let mut stored_response = StoredResponse::new(input_text, output_text, None);

    stored_response.instructions = response_json
        .get("instructions")
        .and_then(|v| v.as_str())
        .map(|s| s.to_string())
        .or_else(|| original_body.instructions.clone());

    stored_response.model = response_json
        .get("model")
        .and_then(|v| v.as_str())
        .map(|s| s.to_string())
        .or_else(|| original_body.model.clone());

    stored_response.user = response_json
        .get("user")
        .and_then(|v| v.as_str())
        .map(|s| s.to_string())
        .or_else(|| original_body.user.clone());

    // Set conversation id from request if provided
    if let Some(conv_id) = original_body.conversation.clone() {
        stored_response.conversation_id = Some(conv_id);
    }

    // metadata: upstream object wins; otherwise the request's metadata (or
    // empty when neither is present).
    stored_response.metadata = response_json
        .get("metadata")
        .and_then(|v| v.as_object())
        .map(|m| {
            m.iter()
                .map(|(k, v)| (k.clone(), v.clone()))
                .collect::<HashMap<_, _>>()
        })
        .unwrap_or_else(|| original_body.metadata.clone().unwrap_or_default());

    stored_response.previous_response_id = response_json
        .get("previous_response_id")
        .and_then(|v| v.as_str())
        .map(ResponseId::from)
        .or_else(|| {
            original_body
                .previous_response_id
                .as_ref()
                .map(|id| ResponseId::from(id.as_str()))
        });

    // Adopt the upstream-assigned id when present.
    if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) {
        stored_response.id = ResponseId::from(id_str);
    }

    stored_response.raw_response = response_json.clone();

    stored_response
}
|
||||
|
||||
/// Store response implementation (public for use across modules)
|
||||
pub(super) async fn store_response_impl(
|
||||
response_storage: &SharedResponseStorage,
|
||||
response_json: &Value,
|
||||
original_body: &ResponsesRequest,
|
||||
) -> Result<ResponseId, String> {
|
||||
let stored_response = build_stored_response(response_json, original_body);
|
||||
|
||||
response_storage
|
||||
.store_response(stored_response)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to store response: {}", e))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Response JSON Patching
|
||||
// ============================================================================
|
||||
|
||||
/// Patch streaming response JSON with metadata from original request
///
/// The final aggregated JSON from a streaming run can omit request-side
/// fields; this fills in previous_response_id, instructions, metadata,
/// model, user, and conversation from the original request. `store` is
/// always overwritten to mirror the request. Each field has its own
/// fill condition — note they are deliberately not identical.
pub(super) fn patch_streaming_response_json(
    response_json: &mut Value,
    original_body: &ResponsesRequest,
    original_previous_response_id: Option<&str>,
) {
    if let Some(obj) = response_json.as_object_mut() {
        // previous_response_id: fill only when missing, null, or empty.
        if let Some(prev_id) = original_previous_response_id {
            let should_insert = obj
                .get("previous_response_id")
                .map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false))
                .unwrap_or(true);
            if should_insert {
                obj.insert(
                    "previous_response_id".to_string(),
                    Value::String(prev_id.to_string()),
                );
            }
        }

        // instructions: fill from the request when absent or null.
        if !obj.contains_key("instructions")
            || obj
                .get("instructions")
                .map(|v| v.is_null())
                .unwrap_or(false)
        {
            if let Some(instructions) = &original_body.instructions {
                obj.insert(
                    "instructions".to_string(),
                    Value::String(instructions.clone()),
                );
            }
        }

        // metadata: fill from the request when absent or null.
        if !obj.contains_key("metadata")
            || obj.get("metadata").map(|v| v.is_null()).unwrap_or(false)
        {
            if let Some(metadata) = &original_body.metadata {
                let metadata_map: serde_json::Map<String, Value> = metadata
                    .iter()
                    .map(|(k, v)| (k.clone(), v.clone()))
                    .collect();
                obj.insert("metadata".to_string(), Value::Object(metadata_map));
            }
        }

        // store always reflects the request, overwriting any upstream value.
        obj.insert("store".to_string(), Value::Bool(original_body.store));

        // model: fill when missing, non-string, or empty.
        if obj
            .get("model")
            .and_then(|v| v.as_str())
            .map(|s| s.is_empty())
            .unwrap_or(true)
        {
            if let Some(model) = &original_body.model {
                obj.insert("model".to_string(), Value::String(model.clone()));
            }
        }

        // user: fill only when the key is present and explicitly null.
        if obj.get("user").map(|v| v.is_null()).unwrap_or(false) {
            if let Some(user) = &original_body.user {
                obj.insert("user".to_string(), Value::String(user.clone()));
            }
        }

        // Attach conversation id for client response if present (final aggregated JSON)
        if let Some(conv_id) = original_body.conversation.clone() {
            obj.insert("conversation".to_string(), json!({"id": conv_id}));
        }
    }
}
|
||||
|
||||
/// Rewrite streaming SSE block to include metadata from original request
///
/// Parses the `data:` payload of one SSE block; for `response.created`,
/// `response.in_progress`, and `response.completed` events it patches
/// `store`, `previous_response_id`, and `conversation` inside the nested
/// `response` object. Returns the rebuilt block, or `None` when nothing
/// changed (the caller should then forward the original block unchanged).
pub(super) fn rewrite_streaming_block(
    block: &str,
    original_body: &ResponsesRequest,
    original_previous_response_id: Option<&str>,
) -> Option<String> {
    let trimmed = block.trim();
    if trimmed.is_empty() {
        return None;
    }

    // Collect every data: line; multi-line payloads are joined with \n below.
    let mut data_lines: Vec<String> = Vec::new();

    for line in trimmed.lines() {
        if line.starts_with("data:") {
            data_lines.push(line.trim_start_matches("data:").trim_start().to_string());
        }
    }

    if data_lines.is_empty() {
        return None;
    }

    let payload = data_lines.join("\n");
    let mut parsed: Value = match serde_json::from_str(&payload) {
        Ok(value) => value,
        Err(err) => {
            // Unparseable payloads pass through untouched.
            warn!("Failed to parse streaming JSON payload: {}", err);
            return None;
        }
    };

    let event_type = parsed
        .get("type")
        .and_then(|v| v.as_str())
        .unwrap_or_default();

    // Only lifecycle events carry a full response object worth patching.
    let should_patch = matches!(
        event_type,
        event_types::RESPONSE_CREATED
            | event_types::RESPONSE_IN_PROGRESS
            | event_types::RESPONSE_COMPLETED
    );

    if !should_patch {
        return None;
    }

    let mut changed = false;
    if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) {
        // store must always mirror the original request.
        let desired_store = Value::Bool(original_body.store);
        if response_obj.get("store") != Some(&desired_store) {
            response_obj.insert("store".to_string(), desired_store);
            changed = true;
        }

        // previous_response_id: fill only when missing, null, or empty.
        if let Some(prev_id) = original_previous_response_id {
            let needs_previous = response_obj
                .get("previous_response_id")
                .map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false))
                .unwrap_or(true);

            if needs_previous {
                response_obj.insert(
                    "previous_response_id".to_string(),
                    Value::String(prev_id.to_string()),
                );
                changed = true;
            }
        }

        // Attach conversation id into streaming event response content with ordering
        if let Some(conv_id) = original_body.conversation.clone() {
            response_obj.insert("conversation".to_string(), json!({"id": conv_id}));
            changed = true;
        }
    }

    if !changed {
        return None;
    }

    let new_payload = match serde_json::to_string(&parsed) {
        Ok(json) => json,
        Err(err) => {
            warn!("Failed to serialize modified streaming payload: {}", err);
            return None;
        }
    };

    // Rebuild the block, collapsing all original data: lines into one
    // patched data: line while keeping every non-data line in place.
    let mut rebuilt_lines = Vec::new();
    let mut data_written = false;
    for line in trimmed.lines() {
        if line.starts_with("data:") {
            if !data_written {
                rebuilt_lines.push(format!("data: {}", new_payload));
                data_written = true;
            }
        } else {
            rebuilt_lines.push(line.to_string());
        }
    }

    if !data_written {
        rebuilt_lines.push(format!("data: {}", new_payload));
    }

    Some(rebuilt_lines.join("\n"))
}
|
||||
|
||||
/// Mask function tools as MCP tools in response for client
|
||||
pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
|
||||
let mcp_tool = original_body
|
||||
.tools
|
||||
.iter()
|
||||
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some());
|
||||
let Some(t) = mcp_tool else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut m = serde_json::Map::new();
|
||||
m.insert("type".to_string(), Value::String("mcp".to_string()));
|
||||
if let Some(label) = &t.server_label {
|
||||
m.insert("server_label".to_string(), Value::String(label.clone()));
|
||||
}
|
||||
if let Some(url) = &t.server_url {
|
||||
m.insert("server_url".to_string(), Value::String(url.clone()));
|
||||
}
|
||||
if let Some(desc) = &t.server_description {
|
||||
m.insert(
|
||||
"server_description".to_string(),
|
||||
Value::String(desc.clone()),
|
||||
);
|
||||
}
|
||||
if let Some(req) = &t.require_approval {
|
||||
m.insert("require_approval".to_string(), Value::String(req.clone()));
|
||||
}
|
||||
if let Some(allowed) = &t.allowed_tools {
|
||||
m.insert(
|
||||
"allowed_tools".to_string(),
|
||||
Value::Array(allowed.iter().map(|s| Value::String(s.clone())).collect()),
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(obj) = resp.as_object_mut() {
|
||||
obj.insert("tools".to_string(), Value::Array(vec![Value::Object(m)]));
|
||||
obj.entry("tool_choice")
|
||||
.or_insert(Value::String("auto".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Output Text Extraction
|
||||
// ============================================================================
|
||||
|
||||
/// Extract primary output text from response JSON
|
||||
pub(super) fn extract_primary_output_text(response_json: &Value) -> Option<String> {
|
||||
if let Some(items) = response_json.get("output").and_then(|v| v.as_array()) {
|
||||
for item in items {
|
||||
if let Some(content) = item.get("content").and_then(|v| v.as_array()) {
|
||||
for part in content {
|
||||
if part
|
||||
.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|t| t == "output_text")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
if let Some(text) = part.get("text").and_then(|v| v.as_str()) {
|
||||
return Some(text.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
909
sgl-router/src/routers/openai/router.rs
Normal file
909
sgl-router/src/routers/openai/router.rs
Normal file
@@ -0,0 +1,909 @@
|
||||
//! OpenAI router - main coordinator that delegates to specialized modules
|
||||
|
||||
use crate::config::CircuitBreakerConfig;
|
||||
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
|
||||
use crate::data_connector::{
|
||||
conversation_items::ListParams, conversation_items::SortOrder, ConversationId, ResponseId,
|
||||
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
|
||||
};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams,
|
||||
ResponsesRequest,
|
||||
};
|
||||
use crate::routers::header_utils::apply_request_headers;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use futures_util::StreamExt;
|
||||
use serde_json::{json, to_value, Value};
|
||||
use std::{
|
||||
any::Any,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info, warn};
|
||||
|
||||
// Import from sibling modules
|
||||
use super::conversations::{
|
||||
create_conversation, delete_conversation, get_conversation, list_conversation_items,
|
||||
persist_conversation_items, update_conversation,
|
||||
};
|
||||
use super::mcp::{
|
||||
execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
|
||||
McpLoopConfig,
|
||||
};
|
||||
use super::responses::{mask_tools_as_mcp, patch_streaming_response_json, store_response_internal};
|
||||
use super::streaming::handle_streaming_response;
|
||||
|
||||
// ============================================================================
|
||||
// OpenAIRouter Struct
|
||||
// ============================================================================
|
||||
|
||||
/// Router for OpenAI backend
///
/// Forwards OpenAI-compatible requests (chat, responses, models) to a single
/// upstream base URL behind a circuit breaker, while persisting responses and
/// conversation history in local storage backends.
pub struct OpenAIRouter {
    /// HTTP client for upstream OpenAI-compatible API
    client: reqwest::Client,
    /// Base URL for identification (no trailing slash)
    base_url: String,
    /// Circuit breaker guarding calls to the upstream endpoint
    circuit_breaker: CircuitBreaker,
    /// Health status (initialized to `true` in `new`; not mutated in this file)
    healthy: AtomicBool,
    /// Response storage for managing conversation history
    response_storage: SharedResponseStorage,
    /// Conversation storage backend
    conversation_storage: SharedConversationStorage,
    /// Conversation item storage backend
    conversation_item_storage: SharedConversationItemStorage,
    /// Optional MCP manager (enabled via config presence, see `new`)
    mcp_manager: Option<Arc<crate::mcp::McpClientManager>>,
}
|
||||
|
||||
/// Manual `Debug` impl exposing only `base_url` and `healthy`; the client and
/// storage handles are intentionally omitted from the output.
impl std::fmt::Debug for OpenAIRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIRouter")
            .field("base_url", &self.base_url)
            .field("healthy", &self.healthy)
            .finish()
    }
}
|
||||
|
||||
impl OpenAIRouter {
    /// Maximum number of conversation items to attach as input when a conversation is provided
    const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100;

    /// Create a new OpenAI router
    ///
    /// Normalizes `base_url` (strips trailing slashes) and builds an HTTP
    /// client with a 300s request timeout. The circuit breaker falls back to
    /// default settings when `circuit_breaker_config` is `None`. An MCP
    /// manager is optionally initialized from the config file named by the
    /// `SGLANG_MCP_CONFIG` env var; load/init failures there are logged as
    /// warnings and the router continues without MCP.
    ///
    /// # Errors
    /// Returns `Err` only when the underlying HTTP client cannot be built.
    pub async fn new(
        base_url: String,
        circuit_breaker_config: Option<CircuitBreakerConfig>,
        response_storage: SharedResponseStorage,
        conversation_storage: SharedConversationStorage,
        conversation_item_storage: SharedConversationItemStorage,
    ) -> Result<Self, String> {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .build()
            .map_err(|e| format!("Failed to create HTTP client: {}", e))?;

        let base_url = base_url.trim_end_matches('/').to_string();

        // Convert circuit breaker config
        let core_cb_config = circuit_breaker_config
            .map(|cb| CoreCircuitBreakerConfig {
                failure_threshold: cb.failure_threshold,
                success_threshold: cb.success_threshold,
                timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
                window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
            })
            .unwrap_or_default();

        let circuit_breaker = CircuitBreaker::with_config(core_cb_config);

        // Optional MCP manager activation via env var path (config-driven gate)
        let mcp_manager = match std::env::var("SGLANG_MCP_CONFIG").ok() {
            Some(path) if !path.trim().is_empty() => {
                match crate::mcp::McpConfig::from_file(&path).await {
                    Ok(cfg) => match crate::mcp::McpClientManager::new(cfg).await {
                        Ok(mgr) => Some(Arc::new(mgr)),
                        Err(err) => {
                            warn!("Failed to initialize MCP manager: {}", err);
                            None
                        }
                    },
                    Err(err) => {
                        warn!("Failed to load MCP config from '{}': {}", path, err);
                        None
                    }
                }
            }
            _ => None,
        };

        Ok(Self {
            client,
            base_url,
            circuit_breaker,
            healthy: AtomicBool::new(true),
            response_storage,
            conversation_storage,
            conversation_item_storage,
            mcp_manager,
        })
    }

    /// Handle non-streaming response with optional MCP tool loop
    ///
    /// When MCP is active (per-request manager from the request's tools takes
    /// precedence over the router-level one), runs the tool-execution loop;
    /// otherwise forwards `payload` to `url` as a plain POST. On success the
    /// response JSON is patched with MCP masking and metadata, then persisted:
    /// conversation items when the request names a conversation, or the bare
    /// response otherwise. Failures at any upstream step record a circuit
    /// breaker failure and return an error response immediately.
    async fn handle_non_streaming_response(
        &self,
        url: String,
        headers: Option<&HeaderMap>,
        mut payload: Value,
        original_body: &ResponsesRequest,
        original_previous_response_id: Option<String>,
    ) -> Response {
        // Check if MCP is active for this request
        let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
        let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());

        let mut response_json: Value;

        // If MCP is active, execute tool loop
        if let Some(mcp) = active_mcp {
            let config = McpLoopConfig::default();

            // Transform MCP tools to function tools
            prepare_mcp_payload_for_streaming(&mut payload, mcp);

            match execute_tool_loop(
                &self.client,
                &url,
                headers,
                payload,
                original_body,
                mcp,
                &config,
            )
            .await
            {
                Ok(resp) => response_json = resp,
                Err(err) => {
                    self.circuit_breaker.record_failure();
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(json!({"error": {"message": err}})),
                    )
                        .into_response();
                }
            }
        } else {
            // No MCP - simple request
            let mut request_builder = self.client.post(&url).json(&payload);
            if let Some(h) = headers {
                request_builder = apply_request_headers(h, request_builder, true);
            }

            let response = match request_builder.send().await {
                Ok(r) => r,
                Err(e) => {
                    self.circuit_breaker.record_failure();
                    return (
                        StatusCode::BAD_GATEWAY,
                        format!("Failed to forward request to OpenAI: {}", e),
                    )
                        .into_response();
                }
            };

            if !response.status().is_success() {
                self.circuit_breaker.record_failure();
                // Mirror the upstream status code; body is passed through verbatim.
                let status = StatusCode::from_u16(response.status().as_u16())
                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                let body = response.text().await.unwrap_or_default();
                return (status, body).into_response();
            }

            response_json = match response.json::<Value>().await {
                Ok(r) => r,
                Err(e) => {
                    self.circuit_breaker.record_failure();
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to parse upstream response: {}", e),
                    )
                        .into_response();
                }
            };

            self.circuit_breaker.record_success();
        }

        // Patch response with metadata
        mask_tools_as_mcp(&mut response_json, original_body);
        patch_streaming_response_json(
            &mut response_json,
            original_body,
            original_previous_response_id.as_deref(),
        );

        // Persist conversation items if conversation is provided
        if original_body.conversation.is_some() {
            if let Err(err) = persist_conversation_items(
                self.conversation_storage.clone(),
                self.conversation_item_storage.clone(),
                self.response_storage.clone(),
                &response_json,
                original_body,
            )
            .await
            {
                // Persistence is best-effort; the client still gets the response.
                warn!("Failed to persist conversation items: {}", err);
            }
        } else {
            // Store response only if no conversation (persist_conversation_items already stores it)
            if let Err(err) =
                store_response_internal(&self.response_storage, &response_json, original_body).await
            {
                warn!("Failed to store response: {}", err);
            }
        }

        (StatusCode::OK, Json(response_json)).into_response()
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// RouterTrait Implementation
|
||||
// ============================================================================
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl crate::routers::RouterTrait for OpenAIRouter {
|
||||
    /// Expose `self` for downcasting by callers that hold a `dyn RouterTrait`.
    fn as_any(&self) -> &dyn Any {
        self
    }
|
||||
|
||||
    /// Probe upstream health via `GET {base_url}/v1/models` with a 2s timeout.
    ///
    /// 2xx as well as 401/403 count as healthy — an auth rejection still
    /// proves the endpoint is reachable. Anything else maps to 503 with the
    /// upstream status or transport error in the body.
    async fn health_generate(&self, _req: Request<Body>) -> Response {
        // Simple upstream probe: GET {base}/v1/models without auth
        let url = format!("{}/v1/models", self.base_url);
        match self
            .client
            .get(&url)
            .timeout(std::time::Duration::from_secs(2))
            .send()
            .await
        {
            Ok(resp) => {
                let code = resp.status();
                // Treat success and auth-required as healthy (endpoint reachable)
                if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
                    (StatusCode::OK, "OK").into_response()
                } else {
                    (
                        StatusCode::SERVICE_UNAVAILABLE,
                        format!("Upstream status: {}", code),
                    )
                        .into_response()
                }
            }
            Err(e) => (
                StatusCode::SERVICE_UNAVAILABLE,
                format!("Upstream error: {}", e),
            )
                .into_response(),
        }
    }
|
||||
|
||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||
let info = json!({
|
||||
"router_type": "openai",
|
||||
"workers": 1,
|
||||
"base_url": &self.base_url
|
||||
});
|
||||
(StatusCode::OK, info.to_string()).into_response()
|
||||
}
|
||||
|
||||
    /// Proxy `GET /v1/models` to the upstream, forwarding the client's
    /// `Authorization` header when present and mirroring the upstream status
    /// and `Content-Type` back to the caller.
    async fn get_models(&self, req: Request<Body>) -> Response {
        // Proxy to upstream /v1/models; forward Authorization header if provided
        let headers = req.headers();

        let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));

        // NOTE(review): HeaderMap lookups are case-insensitive, so the second
        // lookup is presumably redundant — confirm before simplifying.
        if let Some(auth) = headers
            .get("authorization")
            .or_else(|| headers.get("Authorization"))
        {
            upstream = upstream.header("Authorization", auth);
        }

        match upstream.send().await {
            Ok(res) => {
                let status = StatusCode::from_u16(res.status().as_u16())
                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                // Capture Content-Type before the body read consumes `res`.
                let content_type = res.headers().get(CONTENT_TYPE).cloned();
                match res.bytes().await {
                    Ok(body) => {
                        let mut response = Response::new(Body::from(body));
                        *response.status_mut() = status;
                        if let Some(ct) = content_type {
                            response.headers_mut().insert(CONTENT_TYPE, ct);
                        }
                        response
                    }
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to read upstream response: {}", e),
                    )
                        .into_response(),
                }
            }
            Err(e) => (
                StatusCode::BAD_GATEWAY,
                format!("Failed to contact upstream: {}", e),
            )
                .into_response(),
        }
    }
|
||||
|
||||
    /// Not supported for the OpenAI backend; always answers 501.
    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        // Not directly supported without model param; return 501
        (
            StatusCode::NOT_IMPLEMENTED,
            "get_model_info not implemented for OpenAI router",
        )
            .into_response()
    }
|
||||
|
||||
    /// `/generate` is SGLang-specific; the OpenAI backend answers 501.
    async fn route_generate(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &GenerateRequest,
        _model_id: Option<&str>,
    ) -> Response {
        // Generate endpoint is SGLang-specific, not supported for OpenAI backend
        (
            StatusCode::NOT_IMPLEMENTED,
            "Generate endpoint not supported for OpenAI backend",
        )
            .into_response()
    }
|
||||
|
||||
    /// Forward a chat-completion request to `{base_url}/v1/chat/completions`.
    ///
    /// SGLang-only sampling fields are stripped from the serialized payload
    /// before forwarding. The client's `Authorization` header is passed
    /// through. Non-streaming responses are buffered and relayed with the
    /// upstream status and `Content-Type`; streaming responses are piped
    /// through an unbounded channel as `text/event-stream`. Circuit breaker
    /// outcomes are recorded on the non-streaming path only.
    async fn route_chat(
        &self,
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
        _model_id: Option<&str>,
    ) -> Response {
        if !self.circuit_breaker.can_execute() {
            return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
        }

        // Serialize request body, removing SGLang-only fields
        let mut payload = match to_value(body) {
            Ok(v) => v,
            Err(e) => {
                return (
                    StatusCode::BAD_REQUEST,
                    format!("Failed to serialize request: {}", e),
                )
                    .into_response();
            }
        };
        // Fields the upstream OpenAI API does not understand.
        if let Some(obj) = payload.as_object_mut() {
            for key in [
                "top_k",
                "min_p",
                "min_tokens",
                "regex",
                "ebnf",
                "stop_token_ids",
                "no_stop_trim",
                "ignore_eos",
                "continue_final_message",
                "skip_special_tokens",
                "lora_path",
                "session_params",
                "separate_reasoning",
                "stream_reasoning",
                "chat_template_kwargs",
                "return_hidden_states",
                "repetition_penalty",
                "sampling_seed",
            ] {
                obj.remove(key);
            }
        }

        let url = format!("{}/v1/chat/completions", self.base_url);
        let mut req = self.client.post(&url).json(&payload);

        // Forward Authorization header if provided
        if let Some(h) = headers {
            if let Some(auth) = h.get("authorization").or_else(|| h.get("Authorization")) {
                req = req.header("Authorization", auth);
            }
        }

        // Accept SSE when stream=true
        if body.stream {
            req = req.header("Accept", "text/event-stream");
        }

        let resp = match req.send().await {
            Ok(r) => r,
            Err(e) => {
                self.circuit_breaker.record_failure();
                return (
                    StatusCode::SERVICE_UNAVAILABLE,
                    format!("Failed to contact upstream: {}", e),
                )
                    .into_response();
            }
        };

        // Mirror upstream status on the response we build below.
        let status = StatusCode::from_u16(resp.status().as_u16())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

        if !body.stream {
            // Capture Content-Type before consuming response body
            let content_type = resp.headers().get(CONTENT_TYPE).cloned();
            match resp.bytes().await {
                Ok(body) => {
                    self.circuit_breaker.record_success();
                    let mut response = Response::new(Body::from(body));
                    *response.status_mut() = status;
                    if let Some(ct) = content_type {
                        response.headers_mut().insert(CONTENT_TYPE, ct);
                    }
                    response
                }
                Err(e) => {
                    self.circuit_breaker.record_failure();
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to read response: {}", e),
                    )
                        .into_response()
                }
            }
        } else {
            // Stream SSE bytes to client
            let stream = resp.bytes_stream();
            let (tx, rx) = mpsc::unbounded_channel();
            // Pump upstream chunks into the channel; stop when the client
            // disconnects (send fails) or the upstream stream errors.
            tokio::spawn(async move {
                let mut s = stream;
                while let Some(chunk) = s.next().await {
                    match chunk {
                        Ok(bytes) => {
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
            });
            let mut response = Response::new(Body::from_stream(UnboundedReceiverStream::new(rx)));
            *response.status_mut() = status;
            response
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
            response
        }
    }
|
||||
|
||||
    /// Legacy completions endpoint; not implemented for this backend (501).
    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &CompletionRequest,
        _model_id: Option<&str>,
    ) -> Response {
        // Completion endpoint not implemented for OpenAI backend
        (
            StatusCode::NOT_IMPLEMENTED,
            "Completion endpoint not implemented for OpenAI backend",
        )
            .into_response()
    }
|
||||
|
||||
    /// Route a Responses API request to `{base_url}/v1/responses`.
    ///
    /// Rejects requests that set both `previous_response_id` and
    /// `conversation` (mutually exclusive). Prior context is then assembled
    /// locally — either from the stored response chain behind
    /// `previous_response_id` or from the persisted conversation history
    /// (capped at `MAX_CONVERSATION_HISTORY_ITEMS`) — and prepended to the
    /// request input. The upstream never sees `conversation` or
    /// `previous_response_id`, and `store` is forced to `false` because
    /// storage happens locally. Finally delegates to the streaming or
    /// non-streaming handler based on `body.stream`.
    async fn route_responses(
        &self,
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
    ) -> Response {
        let url = format!("{}/v1/responses", self.base_url);

        info!(
            requested_store = body.store,
            is_streaming = body.stream,
            "openai_responses_request"
        );

        // Validate mutually exclusive params: previous_response_id and conversation
        // TODO: this validation logic should move the right place, also we need a proper error message module
        if body.previous_response_id.is_some() && body.conversation.is_some() {
            return (
                StatusCode::BAD_REQUEST,
                Json(json!({
                    "error": {
                        "message": "Mutually exclusive parameters. Ensure you are only providing one of: 'previous_response_id' or 'conversation'.",
                        "type": "invalid_request_error",
                        "param": Value::Null,
                        "code": "mutually_exclusive_parameters"
                    }
                })),
            )
                .into_response();
        }

        // Clone the body and override model if needed
        let mut request_body = body.clone();
        if let Some(model) = model_id {
            request_body.model = Some(model.to_string());
        }
        // Do not forward conversation field upstream; retain for local persistence only
        request_body.conversation = None;

        // Store the original previous_response_id for the response
        let original_previous_response_id = request_body.previous_response_id.clone();

        // Handle previous_response_id by loading prior context
        let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None;
        if let Some(prev_id_str) = request_body.previous_response_id.clone() {
            let prev_id = ResponseId::from(prev_id_str.as_str());
            match self
                .response_storage
                .get_response_chain(&prev_id, None)
                .await
            {
                Ok(chain) => {
                    let mut items = Vec::new();
                    for stored in chain.responses.iter() {
                        // Convert input to conversation item
                        items.push(ResponseInputOutputItem::Message {
                            id: format!("msg_u_{}", stored.id.0.trim_start_matches("resp_")),
                            role: "user".to_string(),
                            content: vec![ResponseContentPart::InputText {
                                text: stored.input.clone(),
                            }],
                            status: Some("completed".to_string()),
                        });

                        // Convert output to conversation items directly from stored response
                        if let Some(output_arr) =
                            stored.raw_response.get("output").and_then(|v| v.as_array())
                        {
                            for item in output_arr {
                                // Items that fail to deserialize are silently skipped.
                                if let Ok(output_item) =
                                    serde_json::from_value::<ResponseInputOutputItem>(item.clone())
                                {
                                    items.push(output_item);
                                }
                            }
                        }
                    }
                    conversation_items = Some(items);
                    request_body.previous_response_id = None;
                }
                Err(e) => {
                    // Best-effort: on lookup failure the request proceeds
                    // without the prior chain.
                    warn!(
                        "Failed to load previous response chain for {}: {}",
                        prev_id_str, e
                    );
                }
            }
        }

        // Handle conversation by loading history
        if let Some(conv_id_str) = body.conversation.clone() {
            let conv_id = ConversationId::from(conv_id_str.as_str());

            // Verify conversation exists
            if let Ok(None) = self.conversation_storage.get_conversation(&conv_id).await {
                return (
                    StatusCode::NOT_FOUND,
                    Json(json!({"error": "Conversation not found"})),
                )
                    .into_response();
            }

            // Load conversation history (ascending order for chronological context)
            let params = ListParams {
                limit: Self::MAX_CONVERSATION_HISTORY_ITEMS,
                order: SortOrder::Asc,
                after: None,
            };

            match self
                .conversation_item_storage
                .list_items(&conv_id, params)
                .await
            {
                Ok(stored_items) => {
                    let mut items: Vec<ResponseInputOutputItem> = Vec::new();
                    for item in stored_items.into_iter() {
                        // Only use message items for conversation context
                        // Skip non-message items (reasoning, function calls, etc.)
                        if item.item_type == "message" {
                            if let Ok(content_parts) =
                                serde_json::from_value::<Vec<ResponseContentPart>>(
                                    item.content.clone(),
                                )
                            {
                                items.push(ResponseInputOutputItem::Message {
                                    id: item.id.0.clone(),
                                    role: item.role.clone().unwrap_or_else(|| "user".to_string()),
                                    content: content_parts,
                                    status: item.status.clone(),
                                });
                            }
                        }
                    }

                    // Append current request
                    match &request_body.input {
                        ResponseInput::Text(text) => {
                            items.push(ResponseInputOutputItem::Message {
                                id: format!("msg_u_{}", conv_id.0),
                                role: "user".to_string(),
                                content: vec![ResponseContentPart::InputText {
                                    text: text.clone(),
                                }],
                                status: Some("completed".to_string()),
                            });
                        }
                        ResponseInput::Items(current_items) => {
                            items.extend_from_slice(current_items);
                        }
                    }

                    request_body.input = ResponseInput::Items(items);
                }
                Err(e) => {
                    // Best-effort: proceed without history on storage failure.
                    warn!("Failed to load conversation history: {}", e);
                }
            }
        }

        // If we have conversation_items from previous_response_id, use them
        if let Some(mut items) = conversation_items {
            // Append current request
            match &request_body.input {
                ResponseInput::Text(text) => {
                    items.push(ResponseInputOutputItem::Message {
                        id: format!(
                            "msg_u_{}",
                            original_previous_response_id
                                .as_ref()
                                .unwrap_or(&"new".to_string())
                        ),
                        role: "user".to_string(),
                        content: vec![ResponseContentPart::InputText { text: text.clone() }],
                        status: Some("completed".to_string()),
                    });
                }
                ResponseInput::Items(current_items) => {
                    items.extend_from_slice(current_items);
                }
            }

            request_body.input = ResponseInput::Items(items);
        }

        // Always set store=false for upstream (we store internally)
        request_body.store = false;

        // Convert to JSON and strip SGLang-specific fields
        let mut payload = match to_value(&request_body) {
            Ok(v) => v,
            Err(e) => {
                return (
                    StatusCode::BAD_REQUEST,
                    format!("Failed to serialize request: {}", e),
                )
                    .into_response();
            }
        };

        // Remove SGLang-specific fields
        if let Some(obj) = payload.as_object_mut() {
            for key in [
                "request_id",
                "priority",
                "top_k",
                "frequency_penalty",
                "presence_penalty",
                "min_p",
                "min_tokens",
                "regex",
                "ebnf",
                "stop_token_ids",
                "no_stop_trim",
                "ignore_eos",
                "continue_final_message",
                "skip_special_tokens",
                "lora_path",
                "session_params",
                "separate_reasoning",
                "stream_reasoning",
                "chat_template_kwargs",
                "return_hidden_states",
                "repetition_penalty",
                "sampling_seed",
            ] {
                obj.remove(key);
            }
        }

        // Delegate to streaming or non-streaming handler
        if body.stream {
            handle_streaming_response(
                &self.client,
                &self.circuit_breaker,
                self.mcp_manager.as_ref(),
                self.response_storage.clone(),
                self.conversation_storage.clone(),
                self.conversation_item_storage.clone(),
                url,
                headers,
                payload,
                body,
                original_previous_response_id,
            )
            .await
        } else {
            self.handle_non_streaming_response(
                url,
                headers,
                payload,
                body,
                original_previous_response_id,
            )
            .await
        }
    }
|
||||
|
||||
    /// Fetch a stored response by id from local storage (no upstream call).
    ///
    /// The stored raw JSON is returned with its `id` field overwritten by the
    /// requested id. 404 when unknown, 500 on a storage error.
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
        response_id: &str,
        _params: &ResponsesGetParams,
    ) -> Response {
        let id = ResponseId::from(response_id);
        match self.response_storage.get_response(&id).await {
            Ok(Some(stored)) => {
                let mut response_json = stored.raw_response;
                if let Some(obj) = response_json.as_object_mut() {
                    obj.insert("id".to_string(), json!(id.0));
                }
                (StatusCode::OK, Json(response_json)).into_response()
            }
            Ok(None) => (
                StatusCode::NOT_FOUND,
                Json(json!({"error": "Response not found"})),
            )
                .into_response(),
            Err(e) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({"error": format!("Failed to get response: {}", e)})),
            )
                .into_response(),
        }
    }
|
||||
|
||||
    /// Proxy a response cancellation to `{base_url}/v1/responses/{id}/cancel`,
    /// forwarding request headers and relaying the upstream status and body.
    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        // Forward cancellation to upstream
        let url = format!("{}/v1/responses/{}/cancel", self.base_url, response_id);
        let mut req = self.client.post(&url);

        if let Some(h) = headers {
            req = apply_request_headers(h, req, false);
        }

        match req.send().await {
            Ok(resp) => {
                let status = StatusCode::from_u16(resp.status().as_u16())
                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                match resp.text().await {
                    Ok(body) => (status, body).into_response(),
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to read response: {}", e),
                    )
                        .into_response(),
                }
            }
            Err(e) => (
                StatusCode::BAD_GATEWAY,
                format!("Failed to contact upstream: {}", e),
            )
                .into_response(),
        }
    }
|
||||
|
||||
    /// Embeddings are not supported by this router; always answers 501.
    async fn route_embeddings(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &EmbeddingRequest,
        _model_id: Option<&str>,
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED, "Embeddings not supported").into_response()
    }
|
||||
|
||||
    /// Rerank is not supported by this router; always answers 501.
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &RerankRequest,
        _model_id: Option<&str>,
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED, "Rerank not supported").into_response()
    }
|
||||
|
||||
    /// Delegate conversation creation to the `conversations` module backed by
    /// local conversation storage.
    async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response {
        create_conversation(&self.conversation_storage, body.clone()).await
    }
|
||||
|
||||
    /// Delegate conversation lookup to the `conversations` module.
    async fn get_conversation(
        &self,
        _headers: Option<&HeaderMap>,
        conversation_id: &str,
    ) -> Response {
        get_conversation(&self.conversation_storage, conversation_id).await
    }
|
||||
|
||||
    /// Delegate conversation update to the `conversations` module.
    async fn update_conversation(
        &self,
        _headers: Option<&HeaderMap>,
        conversation_id: &str,
        body: &Value,
    ) -> Response {
        update_conversation(&self.conversation_storage, conversation_id, body.clone()).await
    }
|
||||
|
||||
    /// Delegate conversation deletion to the `conversations` module.
    async fn delete_conversation(
        &self,
        _headers: Option<&HeaderMap>,
        conversation_id: &str,
    ) -> Response {
        delete_conversation(&self.conversation_storage, conversation_id).await
    }
|
||||
|
||||
    /// Identifier used by the factory/manager to distinguish router kinds.
    fn router_type(&self) -> &'static str {
        "openai"
    }
|
||||
|
||||
async fn list_conversation_items(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
conversation_id: &str,
|
||||
limit: Option<usize>,
|
||||
order: Option<String>,
|
||||
after: Option<String>,
|
||||
) -> Response {
|
||||
let mut query_params = std::collections::HashMap::new();
|
||||
query_params.insert("limit".to_string(), limit.unwrap_or(100).to_string());
|
||||
if let Some(after_val) = after {
|
||||
if !after_val.is_empty() {
|
||||
query_params.insert("after".to_string(), after_val);
|
||||
}
|
||||
}
|
||||
if let Some(order_val) = order {
|
||||
query_params.insert("order".to_string(), order_val);
|
||||
}
|
||||
|
||||
list_conversation_items(
|
||||
&self.conversation_storage,
|
||||
&self.conversation_item_storage,
|
||||
conversation_id,
|
||||
query_params,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
1550
sgl-router/src/routers/openai/streaming.rs
Normal file
1550
sgl-router/src/routers/openai/streaming.rs
Normal file
File diff suppressed because it is too large
Load Diff
100
sgl-router/src/routers/openai/utils.rs
Normal file
100
sgl-router/src/routers/openai/utils.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
//! Utility types and constants for OpenAI router
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ============================================================================
|
||||
// SSE Event Type Constants
|
||||
// ============================================================================
|
||||
|
||||
/// SSE event type constants - single source of truth for event type strings
///
/// These match the `type` field of the SSE events exchanged with the
/// OpenAI-compatible Responses streaming API.
pub(crate) mod event_types {
    // Response lifecycle events
    pub const RESPONSE_CREATED: &str = "response.created";
    pub const RESPONSE_IN_PROGRESS: &str = "response.in_progress";
    pub const RESPONSE_COMPLETED: &str = "response.completed";

    // Output item events
    pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added";
    pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done";
    pub const OUTPUT_ITEM_DELTA: &str = "response.output_item.delta";

    // Function call events
    pub const FUNCTION_CALL_ARGUMENTS_DELTA: &str = "response.function_call_arguments.delta";
    pub const FUNCTION_CALL_ARGUMENTS_DONE: &str = "response.function_call_arguments.done";

    // MCP call events
    pub const MCP_CALL_ARGUMENTS_DELTA: &str = "response.mcp_call_arguments.delta";
    pub const MCP_CALL_ARGUMENTS_DONE: &str = "response.mcp_call_arguments.done";
    pub const MCP_CALL_IN_PROGRESS: &str = "response.mcp_call.in_progress";
    pub const MCP_CALL_COMPLETED: &str = "response.mcp_call.completed";
    pub const MCP_LIST_TOOLS_IN_PROGRESS: &str = "response.mcp_list_tools.in_progress";
    pub const MCP_LIST_TOOLS_COMPLETED: &str = "response.mcp_list_tools.completed";

    // Item types (the `type` field of output items, not event names)
    pub const ITEM_TYPE_FUNCTION_CALL: &str = "function_call";
    pub const ITEM_TYPE_FUNCTION_TOOL_CALL: &str = "function_tool_call";
    pub const ITEM_TYPE_MCP_CALL: &str = "mcp_call";
    pub const ITEM_TYPE_FUNCTION: &str = "function";
    pub const ITEM_TYPE_MCP_LIST_TOOLS: &str = "mcp_list_tools";
}
|
||||
|
||||
// ============================================================================
// Stream Action Enum
// ============================================================================

/// Action to take based on streaming event processing
#[derive(Debug)]
pub(crate) enum StreamAction {
    /// Pass the event straight through to the client.
    Forward,
    /// Accumulate the event while a tool call is being assembled.
    Buffer,
    /// A function call is complete; execute the tool now.
    ExecuteTools,
}
|
||||
|
||||
// ============================================================================
// Output Index Mapper
// ============================================================================

/// Maps upstream output indices to sequential downstream indices.
///
/// Downstream indices are handed out compactly and monotonically, starting
/// from the value given to [`OutputIndexMapper::with_start`] (or `0` via
/// `Default`); synthetic items with no upstream counterpart can also claim
/// an index via [`OutputIndexMapper::allocate_synthetic`].
#[derive(Debug, Default)]
pub(crate) struct OutputIndexMapper {
    // Next downstream index to hand out.
    next_index: usize,
    // upstream output_index -> remapped downstream output_index
    assigned: HashMap<usize, usize>,
}

impl OutputIndexMapper {
    /// Creates a mapper whose first assigned downstream index is `next_index`.
    pub fn with_start(next_index: usize) -> Self {
        Self {
            next_index,
            assigned: HashMap::new(),
        }
    }

    /// Returns the downstream index for `upstream_index`, assigning a fresh
    /// sequential one the first time the upstream index is seen.
    pub fn ensure_mapping(&mut self, upstream_index: usize) -> usize {
        // Borrow the counter separately so the entry closure only captures it,
        // not all of `self` (which `entry` is already borrowing).
        let counter = &mut self.next_index;
        *self.assigned.entry(upstream_index).or_insert_with(|| {
            let fresh = *counter;
            *counter += 1;
            fresh
        })
    }

    /// Looks up a previously assigned downstream index, if any.
    pub fn lookup(&self, upstream_index: usize) -> Option<usize> {
        self.assigned.get(&upstream_index).copied()
    }

    /// Reserves the next downstream index for a router-synthesized item
    /// that has no upstream counterpart (so it is never recorded in the map).
    pub fn allocate_synthetic(&mut self) -> usize {
        let fresh = self.next_index;
        self.next_index += 1;
        fresh
    }

    /// The next downstream index that would be handed out.
    pub fn next_index(&self) -> usize {
        self.next_index
    }
}
|
||||
|
||||
// ============================================================================
// Re-export FunctionCallInProgress from mcp module
// ============================================================================

// Re-exported so crate-internal consumers of this utils module can name
// `FunctionCallInProgress` without importing from `super::mcp` directly.
pub(crate) use super::mcp::FunctionCallInProgress;
|
||||
@@ -22,7 +22,7 @@ use sglang_router_rs::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
|
||||
ResponsesGetParams, ResponsesRequest, UserMessageContent,
|
||||
},
|
||||
routers::{openai_router::OpenAIRouter, RouterTrait},
|
||||
routers::{openai::OpenAIRouter, RouterTrait},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{
|
||||
|
||||
Reference in New Issue
Block a user