[router][grpc] Support v1/responses API (#11926)
@@ -778,7 +778,9 @@ class OpenAIServingResponses(OpenAIServingChat):
         # Update the status to "cancelled"
         response.status = "cancelled"

         # Abort the request
+        # The response_id is the same as the rid used when submitting the request
+        self.tokenizer_manager.abort_request(rid=response_id)

         if task := self.background_tasks.get(response_id):
             task.cancel()
             try:
@@ -52,6 +52,9 @@ pub type ConversationMetadata = JsonMap<String, Value>;
 /// Input payload for creating a conversation
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub struct NewConversation {
+    /// Optional conversation ID (if None, a random ID will be generated)
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub id: Option<ConversationId>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub metadata: Option<ConversationMetadata>,
 }
@@ -68,7 +71,7 @@ pub struct Conversation {
 impl Conversation {
     pub fn new(new_conversation: NewConversation) -> Self {
         Self {
-            id: ConversationId::new(),
+            id: new_conversation.id.unwrap_or_default(),
             created_at: Utc::now(),
             metadata: new_conversation.metadata,
         }
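Note: a minimal, self-contained sketch of the id-override semantics introduced here, using a plain `String` as a stand-in for `ConversationId` (whose `Default` impl is assumed to generate a fresh random ID):

```rust
use serde::Deserialize;
use std::time::{SystemTime, UNIX_EPOCH};

#[derive(Debug, Deserialize, Default)]
struct NewConversation {
    #[serde(default)]
    id: Option<String>, // stand-in for Option<ConversationId>
}

fn generated_id() -> String {
    // Stand-in for the random ID a real ConversationId::default() would produce.
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .subsec_nanos();
    format!("conv_{nanos:08x}")
}

fn main() {
    // Client supplied an ID: it is preserved.
    let with_id: NewConversation = serde_json::from_str(r#"{"id": "conv_abc123"}"#).unwrap();
    assert_eq!(with_id.id.as_deref(), Some("conv_abc123"));

    // No ID in the payload: a fresh one is generated, mirroring
    // `new_conversation.id.unwrap_or_default()` above.
    let without_id: NewConversation = serde_json::from_str("{}").unwrap();
    let id = without_id.id.unwrap_or_else(generated_id);
    assert!(id.starts_with("conv_"));
    println!("resolved id: {id}");
}
```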
@@ -180,21 +180,49 @@ impl McpClientManager {
         let backoff = ExponentialBackoffBuilder::new()
             .with_initial_interval(Duration::from_secs(1))
             .with_max_interval(Duration::from_secs(30))
-            .with_max_elapsed_time(Some(Duration::from_secs(120)))
+            .with_max_elapsed_time(Some(Duration::from_secs(30)))
             .build();

         backoff::future::retry(backoff, || async {
             match Self::connect_server_impl(config).await {
                 Ok(client) => Ok(client),
                 Err(e) => {
-                    tracing::warn!("Failed to connect to '{}', retrying: {}", config.name, e);
-                    Err(backoff::Error::transient(e))
+                    if Self::is_permanent_error(&e) {
+                        tracing::error!(
+                            "Permanent error connecting to '{}': {} - not retrying",
+                            config.name,
+                            e
+                        );
+                        Err(backoff::Error::permanent(e))
+                    } else {
+                        tracing::warn!("Failed to connect to '{}', retrying: {}", config.name, e);
+                        Err(backoff::Error::transient(e))
+                    }
                 }
             }
         })
         .await
     }

+    /// Determine if an error is permanent (should not retry) or transient (should retry)
+    fn is_permanent_error(error: &McpError) -> bool {
+        match error {
+            McpError::Config(_) => true,
+            McpError::Auth(_) => true,
+            McpError::ServerNotFound(_) => true,
+            McpError::Transport(_) => true,
+            McpError::ConnectionFailed(msg) => {
+                msg.contains("initialize")
+                    || msg.contains("connection closed")
+                    || msg.contains("connection refused")
+                    || msg.contains("invalid URL")
+                    || msg.contains("not found")
+            }
+            // Tool-related errors shouldn't occur during connection
+            _ => false,
+        }
+    }

     /// Internal implementation of server connection
     async fn connect_server_impl(
         config: &McpServerConfig,
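Note: the permanent/transient split above maps directly onto the `backoff` crate's two error constructors. A self-contained sketch of the pattern (assuming `backoff = { version = "0.4", features = ["tokio"] }`; the error type and connect function are stand-ins, not the router's real ones):

```rust
use std::{
    sync::atomic::{AtomicU32, Ordering},
    time::Duration,
};

use backoff::ExponentialBackoffBuilder;

#[derive(Debug)]
#[allow(dead_code)]
enum ConnectError {
    InvalidUrl,  // permanent: retrying cannot help
    Unreachable, // transient: worth retrying
}

fn is_permanent(e: &ConnectError) -> bool {
    matches!(e, ConnectError::InvalidUrl)
}

// Stand-in for connect_server_impl: fails twice, then succeeds.
async fn connect_once(attempt: u32) -> Result<&'static str, ConnectError> {
    if attempt < 3 {
        Err(ConnectError::Unreachable)
    } else {
        Ok("connected")
    }
}

#[tokio::main]
async fn main() {
    let policy = ExponentialBackoffBuilder::new()
        .with_initial_interval(Duration::from_millis(10))
        .with_max_elapsed_time(Some(Duration::from_secs(5)))
        .build();

    let attempts = AtomicU32::new(0);
    let result = backoff::future::retry(policy, || async {
        let n = attempts.fetch_add(1, Ordering::SeqCst) + 1;
        match connect_once(n).await {
            Ok(s) => Ok(s),
            // permanent() stops the retry loop immediately;
            // transient() lets the policy schedule another attempt.
            Err(e) if is_permanent(&e) => Err(backoff::Error::permanent(e)),
            Err(e) => Err(backoff::Error::transient(e)),
        }
    })
    .await;

    println!("{result:?} after {} attempts", attempts.load(Ordering::SeqCst));
}
```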
@@ -411,6 +411,14 @@ fn default_repetition_penalty() -> f32 {
     1.0
 }

+fn default_temperature() -> Option<f32> {
+    Some(1.0)
+}
+
+fn default_top_p() -> Option<f32> {
+    Some(1.0)
+}
+
 // ============================================================================
 // Request/Response Types
 // ============================================================================
@@ -477,7 +485,10 @@ pub struct ResponsesRequest {
     pub stream: Option<bool>,

     /// Temperature for sampling
-    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(
+        default = "default_temperature",
+        skip_serializing_if = "Option::is_none"
+    )]
     pub temperature: Option<f32>,

     /// Tool choice behavior
@@ -493,7 +504,7 @@ pub struct ResponsesRequest {
     pub top_logprobs: Option<u32>,

     /// Top-p sampling parameter
-    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(default = "default_top_p", skip_serializing_if = "Option::is_none")]
     pub top_p: Option<f32>,

     /// Truncation behavior
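Note: with `default = "default_temperature"`, serde applies the default only when the field is *absent* from the payload; an explicit JSON `null` still deserializes to `None`. A self-contained sketch with a stand-in struct:

```rust
use serde::Deserialize;

fn default_temperature() -> Option<f32> {
    Some(1.0)
}

#[derive(Debug, Deserialize)]
struct Params {
    #[serde(default = "default_temperature")]
    temperature: Option<f32>,
}

fn main() {
    let missing: Params = serde_json::from_str("{}").unwrap();
    assert_eq!(missing.temperature, Some(1.0)); // field absent -> default applies

    let explicit_null: Params = serde_json::from_str(r#"{"temperature": null}"#).unwrap();
    assert_eq!(explicit_null.temperature, None); // null -> None, default NOT applied

    let set: Params = serde_json::from_str(r#"{"temperature": 0.2}"#).unwrap();
    assert_eq!(set.temperature, Some(0.2));
}
```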
@@ -6,6 +6,7 @@ pub mod context;
 pub mod pd_router;
 pub mod pipeline;
 pub mod processing;
+pub mod responses;
 pub mod router;
 pub mod streaming;
 pub mod utils;
@@ -4,6 +4,8 @@
 //! that transform a RequestContext through its lifecycle.

 use std::{
+    borrow::Cow,
+    collections::HashMap,
     sync::Arc,
     time::{Instant, SystemTime, UNIX_EPOCH},
 };
@@ -12,15 +14,20 @@ use async_trait::async_trait;
 use axum::response::{IntoResponse, Response};
 use proto::DisaggregatedParams;
 use rand::Rng;
+use tokio::sync::RwLock;
 use tracing::{debug, error, warn};
 use uuid::Uuid;

-use super::{context::*, processing, streaming, utils};
+use super::{context::*, processing, responses::BackgroundTaskInfo, streaming, utils};
 use crate::{
     core::{ConnectionMode, Worker, WorkerRegistry, WorkerType},
     grpc_client::proto,
     policies::PolicyRegistry,
-    protocols::{chat::ChatCompletionRequest, common::InputIds, generate::GenerateRequest},
+    protocols::{
+        chat::{ChatCompletionRequest, ChatCompletionResponse},
+        common::InputIds,
+        generate::GenerateRequest,
+    },
     reasoning_parser::ParserFactory as ReasoningParserFactory,
     tokenizer::traits::Tokenizer,
     tool_parser::ParserFactory as ToolParserFactory,
@@ -131,7 +138,7 @@ impl PreparationStage {
             token_ids,
             processed_messages: Some(processed_messages),
             tool_constraints: tool_call_constraint,
-            filtered_request: if matches!(body_ref, std::borrow::Cow::Owned(_)) {
+            filtered_request: if matches!(body_ref, Cow::Owned(_)) {
                 Some(body_ref.into_owned())
             } else {
                 None
@@ -1090,4 +1097,86 @@ impl RequestPipeline {
             None => utils::internal_error_static("No response produced"),
         }
     }
+
+    /// Execute chat pipeline for responses endpoint (Result-based for easier composition)
+    ///
+    /// This is used by the responses module and returns Result instead of Response.
+    /// It also supports background mode cancellation via background_tasks.
+    pub async fn execute_chat_for_responses(
+        &self,
+        request: Arc<ChatCompletionRequest>,
+        headers: Option<http::HeaderMap>,
+        model_id: Option<String>,
+        components: Arc<SharedComponents>,
+        response_id: Option<String>,
+        background_tasks: Option<Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>>,
+    ) -> Result<ChatCompletionResponse, String> {
+        let mut ctx = RequestContext::for_chat(request, headers, model_id, components);
+
+        // Execute each stage in sequence
+        for (idx, stage) in self.stages.iter().enumerate() {
+            match stage.execute(&mut ctx).await {
+                Ok(Some(_response)) => {
+                    // Streaming not supported for responses sync mode
+                    return Err("Streaming is not supported in this context".to_string());
+                }
+                Ok(None) => {
+                    let stage_name = stage.name();
+
+                    // After ClientAcquisitionStage, store client for background task cancellation
+                    if stage_name == "ClientAcquisition" {
+                        if let (Some(ref clients), Some(ref resp_id), Some(ref tasks)) =
+                            (&ctx.state.clients, &response_id, &background_tasks)
+                        {
+                            let client_to_store = match clients {
+                                ClientSelection::Single { client } => client.clone(),
+                                ClientSelection::Dual { decode, .. } => decode.clone(),
+                            };
+
+                            if let Some(task_info) = tasks.write().await.get_mut(resp_id.as_str()) {
+                                *task_info.client.write().await = Some(client_to_store);
+                                debug!("Stored client for response_id: {}", resp_id);
+                            }
+                        }
+                    }
+
+                    // After DispatchMetadataStage, store grpc_request_id for background task cancellation
+                    if stage_name == "DispatchMetadata" {
+                        if let (Some(ref dispatch), Some(ref resp_id), Some(ref tasks)) =
+                            (&ctx.state.dispatch, &response_id, &background_tasks)
+                        {
+                            let grpc_request_id = dispatch.request_id.clone();
+
+                            if let Some(task_info) = tasks.write().await.get_mut(resp_id.as_str()) {
+                                task_info.grpc_request_id = grpc_request_id.clone();
+                                debug!("Stored grpc_request_id for response_id: {}", resp_id);
+                            }
+                        }
+                    }
+
+                    // Continue to next stage
+                    continue;
+                }
+                Err(response) => {
+                    // Error occurred
+                    error!(
+                        "Stage {} ({}) failed with status {}",
+                        idx + 1,
+                        stage.name(),
+                        response.status()
+                    );
+                    return Err(format!("Pipeline stage {} failed", stage.name()));
+                }
+            }
+        }
+
+        // Extract final response
+        match ctx.state.response.final_response {
+            Some(FinalResponse::Chat(response)) => Ok(response),
+            Some(FinalResponse::Generate(_)) => {
+                Err("Internal error: wrong response type".to_string())
+            }
+            None => Err("No response produced".to_string()),
+        }
+    }
 }
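Note: a dependency-free sketch of the control flow in `execute_chat_for_responses` — `Ok(None)` advances to the next stage (with per-stage hooks keyed on the stage name), `Ok(Some(_))` would be an early streamed response and is rejected in this context, and `Err` aborts the pipeline. Stage names and types here are stand-ins, not the router's real ones:

```rust
struct Ctx {
    final_response: Option<String>,
}

type StageResult = Result<Option<String>, String>;

fn run_pipeline(stages: &[(&str, fn(&mut Ctx) -> StageResult)]) -> Result<String, String> {
    let mut ctx = Ctx { final_response: None };
    for (name, stage) in stages {
        match stage(&mut ctx) {
            Ok(Some(_early)) => return Err("streaming not supported here".to_string()),
            Ok(None) => continue, // hook per-stage side effects on `name` here if needed
            Err(e) => return Err(format!("stage {name} failed: {e}")),
        }
    }
    ctx.final_response
        .ok_or_else(|| "no response produced".to_string())
}

fn main() {
    let stages: &[(&str, fn(&mut Ctx) -> StageResult)] = &[
        ("Preparation", |_| Ok(None)),
        ("Dispatch", |ctx| {
            ctx.final_response = Some("done".to_string());
            Ok(None)
        }),
    ];
    println!("{:?}", run_pipeline(stages));
}
```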
@@ -408,10 +408,7 @@ impl ResponseProcessor {
                             tool_type: "function".to_string(),
                             function: FunctionCallResponse {
                                 name: tc.function.name,
-                                arguments: Some(
-                                    serde_json::to_string(&tc.function.arguments)
-                                        .unwrap_or_else(|_| "{}".to_string()),
-                                ),
+                                arguments: Some(tc.function.arguments),
                             },
                         }
                     })
sgl-router/src/routers/grpc/responses/conversions.rs (new file, 365 lines)
@@ -0,0 +1,365 @@
//! Conversion utilities for translating between /v1/responses and /v1/chat/completions formats
//!
//! This module implements the conversion approach where:
//! 1. ResponsesRequest → ChatCompletionRequest (for backend processing)
//! 2. ChatCompletionResponse → ResponsesResponse (for client response)
//!
//! This allows the gRPC router to reuse the existing chat pipeline infrastructure
//! without requiring Python backend changes.

use crate::protocols::{
    chat::{ChatCompletionRequest, ChatCompletionResponse, ChatMessage, UserMessageContent},
    common::{FunctionCallResponse, StreamOptions, ToolCall, UsageInfo},
    responses::{
        ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
        ResponseStatus, ResponsesRequest, ResponsesResponse, ResponsesUsage,
    },
};

/// Convert a ResponsesRequest to ChatCompletionRequest for processing through the chat pipeline
///
/// # Conversion Logic
/// - `input` (text/items) → `messages` (chat messages)
/// - `instructions` → system message (prepended)
/// - `max_output_tokens` → `max_completion_tokens`
/// - Tool-related fields are passed through
/// - Response-specific fields (previous_response_id, conversation) are handled by router
pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest, String> {
    let mut messages = Vec::new();

    // 1. Add system message if instructions provided
    if let Some(instructions) = &req.instructions {
        messages.push(ChatMessage::System {
            content: instructions.clone(),
            name: None,
        });
    }

    // 2. Convert input to chat messages
    match &req.input {
        ResponseInput::Text(text) => {
            // Simple text input → user message
            messages.push(ChatMessage::User {
                content: UserMessageContent::Text(text.clone()),
                name: None,
            });
        }
        ResponseInput::Items(items) => {
            // Structured items → convert each to appropriate chat message
            for item in items {
                match item {
                    ResponseInputOutputItem::Message { role, content, .. } => {
                        // Extract text from content parts
                        let text = extract_text_from_content(content);

                        match role.as_str() {
                            "user" => {
                                messages.push(ChatMessage::User {
                                    content: UserMessageContent::Text(text),
                                    name: None,
                                });
                            }
                            "assistant" => {
                                messages.push(ChatMessage::Assistant {
                                    content: Some(text),
                                    name: None,
                                    tool_calls: None,
                                    reasoning_content: None,
                                });
                            }
                            "system" => {
                                messages.push(ChatMessage::System {
                                    content: text,
                                    name: None,
                                });
                            }
                            _ => {
                                // Unknown role, treat as user message
                                messages.push(ChatMessage::User {
                                    content: UserMessageContent::Text(text),
                                    name: None,
                                });
                            }
                        }
                    }
                    ResponseInputOutputItem::FunctionToolCall {
                        id,
                        name,
                        arguments,
                        output,
                        ..
                    } => {
                        // Tool call from history - add as assistant message with tool call
                        // followed by tool response if output exists

                        // Add assistant message with tool_calls (the LLM's decision)
                        messages.push(ChatMessage::Assistant {
                            content: None,
                            name: None,
                            tool_calls: Some(vec![ToolCall {
                                id: id.clone(),
                                tool_type: "function".to_string(),
                                function: FunctionCallResponse {
                                    name: name.clone(),
                                    arguments: Some(arguments.clone()),
                                },
                            }]),
                            reasoning_content: None,
                        });

                        // Add tool result message if output exists
                        if let Some(output_text) = output {
                            messages.push(ChatMessage::Tool {
                                content: output_text.clone(),
                                tool_call_id: id.clone(),
                            });
                        }
                    }
                    ResponseInputOutputItem::Reasoning { content, .. } => {
                        // Reasoning content - add as assistant message with reasoning_content
                        let reasoning_text = content
                            .iter()
                            .map(|c| match c {
                                crate::protocols::responses::ResponseReasoningContent::ReasoningText { text } => {
                                    text.as_str()
                                }
                            })
                            .collect::<Vec<_>>()
                            .join("\n");

                        messages.push(ChatMessage::Assistant {
                            content: None,
                            name: None,
                            tool_calls: None,
                            reasoning_content: Some(reasoning_text),
                        });
                    }
                }
            }
        }
    }

    // Ensure we have at least one message
    if messages.is_empty() {
        return Err("Request must contain at least one message".to_string());
    }

    // 3. Build ChatCompletionRequest
    let is_streaming = req.stream.unwrap_or(false);

    Ok(ChatCompletionRequest {
        messages,
        model: req.model.clone().unwrap_or_else(|| "default".to_string()),
        temperature: req.temperature,
        max_completion_tokens: req.max_output_tokens,
        stream: is_streaming,
        stream_options: if is_streaming {
            Some(StreamOptions {
                include_usage: Some(true),
            })
        } else {
            None
        },
        parallel_tool_calls: req.parallel_tool_calls,
        top_logprobs: req.top_logprobs,
        top_p: req.top_p,
        skip_special_tokens: true, // Always skip special tokens // TODO: except for gpt-oss
        // Note: tools and tool_choice will be handled separately for MCP transformation
        tools: None,       // Will be set by caller if needed
        tool_choice: None, // Will be set by caller if needed
        ..Default::default()
    })
}

/// Extract text content from ResponseContentPart array
fn extract_text_from_content(content: &[ResponseContentPart]) -> String {
    content
        .iter()
        .filter_map(|part| match part {
            ResponseContentPart::InputText { text } => Some(text.as_str()),
            ResponseContentPart::OutputText { text, .. } => Some(text.as_str()),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join("")
}

/// Convert a ChatCompletionResponse to ResponsesResponse
///
/// # Conversion Logic
/// - `id` → `id` (pass through)
/// - `model` → `model` (pass through)
/// - `choices[0].message` → `output` array (convert to ResponseOutputItem::Message)
/// - `choices[0].finish_reason` → determines `status` (stop/length → Completed)
/// - `created` timestamp → `created_at`
pub fn chat_to_responses(
    chat_resp: &ChatCompletionResponse,
    original_req: &ResponsesRequest,
) -> Result<ResponsesResponse, String> {
    // Extract the first choice (responses API doesn't support n>1)
    let choice = chat_resp
        .choices
        .first()
        .ok_or_else(|| "Chat response contains no choices".to_string())?;

    // Convert assistant message to output items
    let mut output: Vec<ResponseOutputItem> = Vec::new();

    // Convert message content to output item
    if let Some(content) = &choice.message.content {
        if !content.is_empty() {
            output.push(ResponseOutputItem::Message {
                id: format!("msg_{}", chat_resp.id),
                role: "assistant".to_string(),
                content: vec![ResponseContentPart::OutputText {
                    text: content.clone(),
                    annotations: vec![],
                    logprobs: choice.logprobs.clone(),
                }],
                status: "completed".to_string(),
            });
        }
    }

    // Convert reasoning content if present (O1-style models)
    if let Some(reasoning) = &choice.message.reasoning_content {
        if !reasoning.is_empty() {
            output.push(ResponseOutputItem::Reasoning {
                id: format!("reasoning_{}", chat_resp.id),
                summary: vec![],
                content: vec![
                    crate::protocols::responses::ResponseReasoningContent::ReasoningText {
                        text: reasoning.clone(),
                    },
                ],
                status: Some("completed".to_string()),
            });
        }
    }

    // Convert tool calls if present
    if let Some(tool_calls) = &choice.message.tool_calls {
        for tool_call in tool_calls {
            output.push(ResponseOutputItem::FunctionToolCall {
                id: tool_call.id.clone(),
                name: tool_call.function.name.clone(),
                arguments: tool_call.function.arguments.clone().unwrap_or_default(),
                output: None, // Tool hasn't been executed yet
                status: "in_progress".to_string(),
            });
        }
    }

    // Determine response status based on finish_reason
    let status = match choice.finish_reason.as_deref() {
        Some("stop") | Some("length") => ResponseStatus::Completed,
        Some("tool_calls") => ResponseStatus::InProgress, // Waiting for tool execution
        Some("failed") | Some("error") => ResponseStatus::Failed,
        _ => ResponseStatus::Completed, // Default to completed
    };

    // Convert usage from Usage to UsageInfo, then wrap in ResponsesUsage
    let usage = chat_resp.usage.as_ref().map(|u| {
        let usage_info = UsageInfo {
            prompt_tokens: u.prompt_tokens,
            completion_tokens: u.completion_tokens,
            total_tokens: u.total_tokens,
            reasoning_tokens: u
                .completion_tokens_details
                .as_ref()
                .and_then(|d| d.reasoning_tokens),
            prompt_tokens_details: None, // Chat response doesn't have this
        };
        ResponsesUsage::Classic(usage_info)
    });

    // Generate response
    Ok(ResponsesResponse {
        id: chat_resp.id.clone(),
        object: "response".to_string(),
        created_at: chat_resp.created as i64,
        status,
        error: None,
        incomplete_details: None,
        instructions: original_req.instructions.clone(),
        max_output_tokens: original_req.max_output_tokens,
        model: chat_resp.model.clone(),
        output,
        parallel_tool_calls: original_req.parallel_tool_calls.unwrap_or(true),
        previous_response_id: original_req.previous_response_id.clone(),
        reasoning: None, // TODO: Map reasoning effort if needed
        store: original_req.store.unwrap_or(true),
        temperature: original_req.temperature,
        text: None,
        tool_choice: "auto".to_string(), // TODO: Map from original request
        tools: original_req.tools.clone().unwrap_or_default(),
        top_p: original_req.top_p,
        truncation: None,
        usage,
        user: None, // No user field in chat response
        metadata: original_req.metadata.clone().unwrap_or_default(),
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_input_conversion() {
        let req = ResponsesRequest {
            input: ResponseInput::Text("Hello, world!".to_string()),
            instructions: Some("You are a helpful assistant.".to_string()),
            model: Some("gpt-4".to_string()),
            temperature: Some(0.7),
            ..Default::default()
        };

        let chat_req = responses_to_chat(&req).unwrap();
        assert_eq!(chat_req.messages.len(), 2); // system + user
        assert_eq!(chat_req.model, "gpt-4");
        assert_eq!(chat_req.temperature, Some(0.7));
    }

    #[test]
    fn test_items_input_conversion() {
        let req = ResponsesRequest {
            input: ResponseInput::Items(vec![
                ResponseInputOutputItem::Message {
                    id: "msg_1".to_string(),
                    role: "user".to_string(),
                    content: vec![ResponseContentPart::InputText {
                        text: "Hello!".to_string(),
                    }],
                    status: None,
                },
                ResponseInputOutputItem::Message {
                    id: "msg_2".to_string(),
                    role: "assistant".to_string(),
                    content: vec![ResponseContentPart::OutputText {
                        text: "Hi there!".to_string(),
                        annotations: vec![],
                        logprobs: None,
                    }],
                    status: None,
                },
            ]),
            ..Default::default()
        };

        let chat_req = responses_to_chat(&req).unwrap();
        assert_eq!(chat_req.messages.len(), 2); // user + assistant
    }

    #[test]
    fn test_empty_input_error() {
        let req = ResponsesRequest {
            input: ResponseInput::Text("".to_string()),
            ..Default::default()
        };

        // Empty text should still create a user message, so this should succeed
        let result = responses_to_chat(&req);
        assert!(result.is_ok());
    }
}
sgl-router/src/routers/grpc/responses/handlers.rs (new file, 1290 lines — diff suppressed because it is too large)
sgl-router/src/routers/grpc/responses/mod.rs (new file, 20 lines)
@@ -0,0 +1,20 @@
//! gRPC Router `/v1/responses` endpoint implementation
//!
//! This module handles all responses-specific logic including:
//! - Request validation
//! - Conversation history and response chain loading
//! - Background mode execution
//! - Streaming support
//! - MCP tool loop wrapper
//! - Response persistence

// Module declarations
mod conversions;
mod handlers;
pub mod streaming;
pub mod tool_loop;
pub mod types;

// Public exports
pub use handlers::{cancel_response_impl, get_response_impl, route_responses};
pub use types::BackgroundTaskInfo;
sgl-router/src/routers/grpc/responses/streaming.rs (new file, 574 lines)
@@ -0,0 +1,574 @@
//! Streaming infrastructure for /v1/responses endpoint

use std::collections::HashMap;

use bytes::Bytes;
use serde_json::json;
use tokio::sync::mpsc;
use uuid::Uuid;

use crate::protocols::chat::ChatCompletionStreamResponse;

pub(super) enum OutputItemType {
    Message,
    McpListTools,
    McpCall,
    Reasoning,
}

/// Status of an output item
#[derive(Debug, Clone, PartialEq)]
enum ItemStatus {
    InProgress,
    Completed,
}

/// State tracking for a single output item
#[derive(Debug, Clone)]
struct OutputItemState {
    output_index: usize,
    status: ItemStatus,
}

// ============================================================================
// Streaming Event Emitter
// ============================================================================

/// OpenAI-compatible event emitter for /v1/responses streaming
///
/// Manages state and sequence numbers to emit proper event types:
/// - response.created
/// - response.in_progress
/// - response.output_item.added
/// - response.content_part.added
/// - response.output_text.delta (multiple)
/// - response.output_text.done
/// - response.content_part.done
/// - response.output_item.done
/// - response.completed
/// - response.mcp_list_tools.in_progress
/// - response.mcp_list_tools.completed
/// - response.mcp_call.in_progress
/// - response.mcp_call_arguments.delta
/// - response.mcp_call_arguments.done
/// - response.mcp_call.completed
/// - response.mcp_call.failed
pub(super) struct ResponseStreamEventEmitter {
    sequence_number: u64,
    response_id: String,
    model: String,
    created_at: u64,
    message_id: String,
    accumulated_text: String,
    has_emitted_created: bool,
    has_emitted_in_progress: bool,
    has_emitted_output_item_added: bool,
    has_emitted_content_part_added: bool,
    // MCP call tracking
    mcp_call_accumulated_args: HashMap<String, String>,
    // Output item tracking
    output_items: Vec<OutputItemState>,
    next_output_index: usize,
    current_message_output_index: Option<usize>, // Tracks output_index of current message
    current_item_id: Option<String>,             // Tracks item_id of current item
}

impl ResponseStreamEventEmitter {
    pub(super) fn new(response_id: String, model: String, created_at: u64) -> Self {
        let message_id = format!("msg_{}", Uuid::new_v4());

        Self {
            sequence_number: 0,
            response_id,
            model,
            created_at,
            message_id,
            accumulated_text: String::new(),
            has_emitted_created: false,
            has_emitted_in_progress: false,
            has_emitted_output_item_added: false,
            has_emitted_content_part_added: false,
            mcp_call_accumulated_args: HashMap::new(),
            output_items: Vec::new(),
            next_output_index: 0,
            current_message_output_index: None,
            current_item_id: None,
        }
    }

    fn next_sequence(&mut self) -> u64 {
        let seq = self.sequence_number;
        self.sequence_number += 1;
        seq
    }

    pub(super) fn emit_created(&mut self) -> serde_json::Value {
        self.has_emitted_created = true;
        json!({
            "type": "response.created",
            "sequence_number": self.next_sequence(),
            "response": {
                "id": self.response_id,
                "object": "response",
                "created_at": self.created_at,
                "status": "in_progress",
                "model": self.model,
                "output": []
            }
        })
    }

    pub(super) fn emit_in_progress(&mut self) -> serde_json::Value {
        self.has_emitted_in_progress = true;
        json!({
            "type": "response.in_progress",
            "sequence_number": self.next_sequence(),
            "response": {
                "id": self.response_id,
                "object": "response",
                "status": "in_progress"
            }
        })
    }

    pub(super) fn emit_content_part_added(
        &mut self,
        output_index: usize,
        item_id: &str,
        content_index: usize,
    ) -> serde_json::Value {
        self.has_emitted_content_part_added = true;
        json!({
            "type": "response.content_part.added",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "item_id": item_id,
            "content_index": content_index,
            "part": {
                "type": "text",
                "text": ""
            }
        })
    }

    pub(super) fn emit_text_delta(
        &mut self,
        delta: &str,
        output_index: usize,
        item_id: &str,
        content_index: usize,
    ) -> serde_json::Value {
        self.accumulated_text.push_str(delta);
        json!({
            "type": "response.output_text.delta",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "item_id": item_id,
            "content_index": content_index,
            "delta": delta
        })
    }

    pub(super) fn emit_text_done(
        &mut self,
        output_index: usize,
        item_id: &str,
        content_index: usize,
    ) -> serde_json::Value {
        json!({
            "type": "response.output_text.done",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "item_id": item_id,
            "content_index": content_index,
            "text": self.accumulated_text.clone()
        })
    }

    pub(super) fn emit_content_part_done(
        &mut self,
        output_index: usize,
        item_id: &str,
        content_index: usize,
    ) -> serde_json::Value {
        json!({
            "type": "response.content_part.done",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "item_id": item_id,
            "content_index": content_index,
            "part": {
                "type": "text",
                "text": self.accumulated_text.clone()
            }
        })
    }

    pub(super) fn emit_completed(
        &mut self,
        usage: Option<&serde_json::Value>,
    ) -> serde_json::Value {
        let mut response = json!({
            "type": "response.completed",
            "sequence_number": self.next_sequence(),
            "response": {
                "id": self.response_id,
                "object": "response",
                "created_at": self.created_at,
                "status": "completed",
                "model": self.model,
                "output": [{
                    "id": self.message_id.clone(),
                    "type": "message",
                    "role": "assistant",
                    "content": [{
                        "type": "text",
                        "text": self.accumulated_text.clone()
                    }]
                }]
            }
        });

        if let Some(usage_val) = usage {
            response["response"]["usage"] = usage_val.clone();
        }

        response
    }

    // ========================================================================
    // MCP Event Emission Methods
    // ========================================================================

    pub(super) fn emit_mcp_list_tools_in_progress(
        &mut self,
        output_index: usize,
    ) -> serde_json::Value {
        json!({
            "type": "response.mcp_list_tools.in_progress",
            "sequence_number": self.next_sequence(),
            "output_index": output_index
        })
    }

    pub(super) fn emit_mcp_list_tools_completed(
        &mut self,
        output_index: usize,
        tools: &[crate::mcp::ToolInfo],
    ) -> serde_json::Value {
        let tool_items: Vec<_> = tools
            .iter()
            .map(|t| {
                json!({
                    "name": t.name,
                    "description": t.description,
                    "input_schema": t.parameters.clone().unwrap_or_else(|| json!({
                        "type": "object",
                        "properties": {},
                        "required": []
                    }))
                })
            })
            .collect();

        json!({
            "type": "response.mcp_list_tools.completed",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "tools": tool_items
        })
    }

    pub(super) fn emit_mcp_call_in_progress(
        &mut self,
        output_index: usize,
        item_id: &str,
    ) -> serde_json::Value {
        json!({
            "type": "response.mcp_call.in_progress",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "item_id": item_id
        })
    }

    pub(super) fn emit_mcp_call_arguments_delta(
        &mut self,
        output_index: usize,
        item_id: &str,
        delta: &str,
    ) -> serde_json::Value {
        // Accumulate arguments for this call
        self.mcp_call_accumulated_args
            .entry(item_id.to_string())
            .or_default()
            .push_str(delta);

        json!({
            "type": "response.mcp_call_arguments.delta",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "item_id": item_id,
            "delta": delta
        })
    }

    pub(super) fn emit_mcp_call_arguments_done(
        &mut self,
        output_index: usize,
        item_id: &str,
        arguments: &str,
    ) -> serde_json::Value {
        json!({
            "type": "response.mcp_call_arguments.done",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "item_id": item_id,
            "arguments": arguments
        })
    }

    pub(super) fn emit_mcp_call_completed(
        &mut self,
        output_index: usize,
        item_id: &str,
    ) -> serde_json::Value {
        json!({
            "type": "response.mcp_call.completed",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "item_id": item_id
        })
    }

    pub(super) fn emit_mcp_call_failed(
        &mut self,
        output_index: usize,
        item_id: &str,
        error: &str,
    ) -> serde_json::Value {
        json!({
            "type": "response.mcp_call.failed",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "item_id": item_id,
            "error": error
        })
    }

    // ========================================================================
    // Output Item Wrapper Events
    // ========================================================================

    /// Emit response.output_item.added event
    pub(super) fn emit_output_item_added(
        &mut self,
        output_index: usize,
        item: &serde_json::Value,
    ) -> serde_json::Value {
        json!({
            "type": "response.output_item.added",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "item": item
        })
    }

    /// Emit response.output_item.done event
    pub(super) fn emit_output_item_done(
        &mut self,
        output_index: usize,
        item: &serde_json::Value,
    ) -> serde_json::Value {
        json!({
            "type": "response.output_item.done",
            "sequence_number": self.next_sequence(),
            "output_index": output_index,
            "item": item
        })
    }

    /// Generate unique ID for item type
    fn generate_item_id(prefix: &str) -> String {
        format!("{}_{}", prefix, Uuid::new_v4().to_string().replace("-", ""))
    }

    /// Allocate next output index and track item
    pub(super) fn allocate_output_index(&mut self, item_type: OutputItemType) -> (usize, String) {
        let index = self.next_output_index;
        self.next_output_index += 1;

        let id_prefix = match &item_type {
            OutputItemType::McpListTools => "mcpl",
            OutputItemType::McpCall => "mcp",
            OutputItemType::Message => "msg",
            OutputItemType::Reasoning => "rs",
        };

        let id = Self::generate_item_id(id_prefix);

        self.output_items.push(OutputItemState {
            output_index: index,
            status: ItemStatus::InProgress,
        });

        (index, id)
    }

    /// Mark output item as completed
    pub(super) fn complete_output_item(&mut self, output_index: usize) {
        if let Some(item) = self
            .output_items
            .iter_mut()
            .find(|i| i.output_index == output_index)
        {
            item.status = ItemStatus::Completed;
        }
    }

    /// Emit reasoning item wrapper events (added + done)
    ///
    /// Reasoning items in OpenAI format are simple placeholders emitted between tool iterations.
    /// They don't have streaming content - just wrapper events with empty/null content.
    pub(super) fn emit_reasoning_item(
        &mut self,
        tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
        reasoning_content: Option<String>,
    ) -> Result<(), String> {
        // Allocate output index and generate ID
        let (output_index, item_id) = self.allocate_output_index(OutputItemType::Reasoning);

        // Build reasoning item structure
        let item = json!({
            "id": item_id,
            "type": "reasoning",
            "summary": [],
            "content": reasoning_content,
            "encrypted_content": null,
            "status": null
        });

        // Emit output_item.added
        let added_event = self.emit_output_item_added(output_index, &item);
        self.send_event(&added_event, tx)?;

        // Immediately emit output_item.done (no streaming for reasoning)
        let done_event = self.emit_output_item_done(output_index, &item);
        self.send_event(&done_event, tx)?;

        // Mark as completed
        self.complete_output_item(output_index);

        Ok(())
    }

    /// Process a chunk and emit appropriate events
    pub(super) fn process_chunk(
        &mut self,
        chunk: &ChatCompletionStreamResponse,
        tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
    ) -> Result<(), String> {
        // Process content if present
        if let Some(choice) = chunk.choices.first() {
            if let Some(content) = &choice.delta.content {
                if !content.is_empty() {
                    // Allocate output_index and item_id for this message item (once per message)
                    if self.current_item_id.is_none() {
                        let (output_index, item_id) =
                            self.allocate_output_index(OutputItemType::Message);

                        // Build message item structure
                        let item = json!({
                            "id": item_id,
                            "type": "message",
                            "role": "assistant",
                            "content": []
                        });

                        // Emit output_item.added
                        let event = self.emit_output_item_added(output_index, &item);
                        self.send_event(&event, tx)?;
                        self.has_emitted_output_item_added = true;

                        // Store for subsequent events
                        self.current_item_id = Some(item_id);
                        self.current_message_output_index = Some(output_index);
                    }

                    let output_index = self.current_message_output_index.unwrap();
                    let item_id = self.current_item_id.clone().unwrap(); // Clone to avoid borrow checker issues
                    let content_index = 0; // Single content part for now

                    // Emit content_part.added before first delta
                    if !self.has_emitted_content_part_added {
                        let event =
                            self.emit_content_part_added(output_index, &item_id, content_index);
                        self.send_event(&event, tx)?;
                        self.has_emitted_content_part_added = true;
                    }

                    // Emit text delta
                    let event =
                        self.emit_text_delta(content, output_index, &item_id, content_index);
                    self.send_event(&event, tx)?;
                }
            }

            // Check for finish_reason to emit completion events
            if let Some(reason) = &choice.finish_reason {
                if reason == "stop" || reason == "length" {
                    let output_index = self.current_message_output_index.unwrap();
                    let item_id = self.current_item_id.clone().unwrap(); // Clone to avoid borrow checker issues
                    let content_index = 0;

                    // Emit closing events
                    if self.has_emitted_content_part_added {
                        let event = self.emit_text_done(output_index, &item_id, content_index);
                        self.send_event(&event, tx)?;
                        let event =
                            self.emit_content_part_done(output_index, &item_id, content_index);
                        self.send_event(&event, tx)?;
                    }

                    if self.has_emitted_output_item_added {
                        // Build complete message item for output_item.done
                        let item = json!({
                            "id": item_id,
                            "type": "message",
                            "role": "assistant",
                            "content": [{
                                "type": "text",
                                "text": self.accumulated_text.clone()
                            }]
                        });
                        let event = self.emit_output_item_done(output_index, &item);
                        self.send_event(&event, tx)?;
                    }

                    // Mark item as completed
                    self.complete_output_item(output_index);
                }
            }
        }

        Ok(())
    }

    pub(super) fn send_event(
        &self,
        event: &serde_json::Value,
        tx: &mpsc::UnboundedSender<Result<Bytes, std::io::Error>>,
    ) -> Result<(), String> {
        let event_json = serde_json::to_string(event)
            .map_err(|e| format!("Failed to serialize event: {}", e))?;

        if tx
            .send(Ok(Bytes::from(format!("data: {}\n\n", event_json))))
            .is_err()
        {
            return Err("Client disconnected".to_string());
        }

        Ok(())
    }
}
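Note: a self-contained sketch of the wire format `send_event` produces — each JSON event carries a monotonically increasing `sequence_number` and is framed as an SSE `data:` block. The channel setup and event payloads here are stand-ins:

```rust
use bytes::Bytes;
use serde_json::json;
use tokio::sync::mpsc;

// Frame one JSON event as an SSE data block, as send_event does.
fn frame(event: &serde_json::Value) -> Bytes {
    Bytes::from(format!("data: {}\n\n", event))
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel::<Result<Bytes, std::io::Error>>();

    let mut seq = 0u64;
    let mut next = || {
        let s = seq;
        seq += 1;
        s
    };

    for event in [
        json!({"type": "response.created", "sequence_number": next()}),
        json!({"type": "response.output_text.delta", "sequence_number": next(), "delta": "Hi"}),
        json!({"type": "response.completed", "sequence_number": next()}),
    ] {
        tx.send(Ok(frame(&event))).expect("receiver alive");
    }
    drop(tx); // close the stream

    while let Some(Ok(bytes)) = rx.recv().await {
        print!("{}", String::from_utf8_lossy(&bytes));
    }
}
```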
sgl-router/src/routers/grpc/responses/tool_loop.rs (new file, 1114 lines — diff suppressed because it is too large)
sgl-router/src/routers/grpc/responses/types.rs (new file, 18 lines)
@@ -0,0 +1,18 @@
//! Type definitions for /v1/responses endpoint

use std::sync::Arc;

use tokio::{sync::RwLock, task::JoinHandle};

/// Information stored for background tasks to enable end-to-end cancellation
///
/// This struct enables cancelling both the Rust task AND the Python scheduler processing.
/// The client field is lazily initialized during pipeline execution.
pub struct BackgroundTaskInfo {
    /// Tokio task handle for aborting the Rust task
    pub handle: JoinHandle<()>,
    /// gRPC request_id sent to Python scheduler (chatcmpl-* prefix)
    pub grpc_request_id: String,
    /// gRPC client for sending abort requests to Python (set after client acquisition)
    pub client: Arc<RwLock<Option<crate::grpc_client::SglangSchedulerClient>>>,
}
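Note: a minimal, self-contained sketch of the bookkeeping pattern `BackgroundTaskInfo` supports — register the `JoinHandle` under the response id, then abort it on cancel. The real struct also carries the lazily set scheduler client so the Python side can be aborted too; the names below are stand-ins:

```rust
use std::{collections::HashMap, sync::Arc, time::Duration};

use tokio::{sync::RwLock, task::JoinHandle};

struct TaskInfo {
    handle: JoinHandle<()>,
    // In the router this entry also carries a lazily initialized gRPC client
    // so the backend scheduler can be told to abort `grpc_request_id`.
    grpc_request_id: String,
}

type Tasks = Arc<RwLock<HashMap<String, TaskInfo>>>;

async fn cancel(tasks: &Tasks, response_id: &str) {
    if let Some(info) = tasks.write().await.remove(response_id) {
        // Abort the local tokio task; a real implementation would also send
        // an abort for `info.grpc_request_id` to the backend.
        info.handle.abort();
        println!("aborted {response_id} (rid={})", info.grpc_request_id);
    }
}

#[tokio::main]
async fn main() {
    let tasks: Tasks = Arc::new(RwLock::new(HashMap::new()));

    let handle = tokio::spawn(async {
        tokio::time::sleep(Duration::from_secs(60)).await; // long-running work
    });
    tasks.write().await.insert(
        "resp_123".to_string(),
        TaskInfo {
            handle,
            grpc_request_id: "chatcmpl-abc".to_string(),
        },
    );

    cancel(&tasks, "resp_123").await;
}
```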
@@ -1,6 +1,6 @@
 // gRPC Router Implementation

-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};

 use async_trait::async_trait;
 use axum::{
@@ -9,12 +9,20 @@ use axum::{
     http::{HeaderMap, StatusCode},
     response::{IntoResponse, Response},
 };
+use tokio::sync::RwLock;
 use tracing::debug;

-use super::{context::SharedComponents, pipeline::RequestPipeline};
+use super::{
+    context::SharedComponents,
+    pipeline::RequestPipeline,
+    responses::{self, BackgroundTaskInfo},
+};
 use crate::{
     config::types::RetryConfig,
     core::WorkerRegistry,
+    data_connector::{
+        SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
+    },
     policies::PolicyRegistry,
     protocols::{
         chat::ChatCompletionRequest,
@@ -48,6 +56,14 @@ pub struct GrpcRouter {
     configured_tool_parser: Option<String>,
     pipeline: RequestPipeline,
     shared_components: Arc<SharedComponents>,
+    // Storage backends for /v1/responses support
+    response_storage: SharedResponseStorage,
+    conversation_storage: SharedConversationStorage,
+    conversation_item_storage: SharedConversationItemStorage,
+    // Optional MCP manager for tool execution (enabled via SGLANG_MCP_CONFIG env var)
+    mcp_manager: Option<Arc<crate::mcp::McpClientManager>>,
+    // Background task handles for cancellation support (includes gRPC client for Python abort)
+    background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
 }

 impl GrpcRouter {
@@ -73,6 +89,31 @@ impl GrpcRouter {
         let worker_registry = ctx.worker_registry.clone();
         let policy_registry = ctx.policy_registry.clone();

+        // Extract storage backends from context
+        let response_storage = ctx.response_storage.clone();
+        let conversation_storage = ctx.conversation_storage.clone();
+        let conversation_item_storage = ctx.conversation_item_storage.clone();
+
+        // Optional MCP manager activation via env var path (config-driven gate)
+        let mcp_manager = match std::env::var("SGLANG_MCP_CONFIG").ok() {
+            Some(path) if !path.trim().is_empty() => {
+                match crate::mcp::McpConfig::from_file(&path).await {
+                    Ok(cfg) => match crate::mcp::McpClientManager::new(cfg).await {
+                        Ok(mgr) => Some(Arc::new(mgr)),
+                        Err(err) => {
+                            tracing::warn!("Failed to initialize MCP manager: {}", err);
+                            None
+                        }
+                    },
+                    Err(err) => {
+                        tracing::warn!("Failed to load MCP config from '{}': {}", path, err);
+                        None
+                    }
+                }
+            }
+            _ => None,
+        };
+
         // Create shared components for pipeline
         let shared_components = Arc::new(SharedComponents {
             tokenizer: tokenizer.clone(),
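Note: the env-var gate above degrades to `None` on any failure instead of failing router startup. A dependency-free sketch of the same shape (the loader is a stand-in, not the real `McpConfig` API):

```rust
use std::sync::Arc;

struct McpManager; // stand-in

fn load_manager(path: &str) -> Result<McpManager, String> {
    if path.ends_with(".json") {
        Ok(McpManager)
    } else {
        Err("unsupported config".into())
    }
}

fn main() {
    let mcp_manager: Option<Arc<McpManager>> = match std::env::var("SGLANG_MCP_CONFIG").ok() {
        Some(path) if !path.trim().is_empty() => match load_manager(&path) {
            Ok(mgr) => Some(Arc::new(mgr)),
            Err(err) => {
                eprintln!("Failed to initialize MCP manager: {err}");
                None // feature disabled, startup continues
            }
        },
        _ => None, // env var unset or empty: feature off
    };
    println!("MCP enabled: {}", mcp_manager.is_some());
}
```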
@@ -104,6 +145,11 @@ impl GrpcRouter {
             configured_tool_parser: ctx.configured_tool_parser.clone(),
             pipeline,
             shared_components,
+            response_storage,
+            conversation_storage,
+            conversation_item_storage,
+            mcp_manager,
+            background_tasks: Arc::new(RwLock::new(HashMap::new())),
         })
     }
@@ -217,24 +263,45 @@ impl RouterTrait for GrpcRouter {

     async fn route_responses(
         &self,
-        _headers: Option<&HeaderMap>,
-        _body: &ResponsesRequest,
-        _model_id: Option<&str>,
+        headers: Option<&HeaderMap>,
+        body: &ResponsesRequest,
+        model_id: Option<&str>,
     ) -> Response {
-        (StatusCode::NOT_IMPLEMENTED).into_response()
+        // Use responses module for ALL requests (streaming and non-streaming)
+        // Responses module handles:
+        // - Request validation (previous_response_id XOR conversation)
+        // - Loading response chain / conversation history from storage
+        // - Conversion: ResponsesRequest → ChatCompletionRequest
+        // - Execution through chat pipeline stages
+        // - Conversion: ChatCompletionResponse → ResponsesResponse
+        // - Response persistence
+        // - MCP tool loop wrapper (future)
+        responses::route_responses(
+            &self.pipeline,
+            Arc::new(body.clone()),
+            headers.cloned(),
+            model_id.map(|s| s.to_string()),
+            self.shared_components.clone(),
+            self.response_storage.clone(),
+            self.conversation_storage.clone(),
+            self.conversation_item_storage.clone(),
+            self.background_tasks.clone(),
+        )
+        .await
     }

     async fn get_response(
         &self,
         _headers: Option<&HeaderMap>,
-        _response_id: &str,
+        response_id: &str,
         _params: &ResponsesGetParams,
     ) -> Response {
-        (StatusCode::NOT_IMPLEMENTED).into_response()
+        responses::get_response_impl(&self.response_storage, response_id).await
     }

-    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
-        (StatusCode::NOT_IMPLEMENTED).into_response()
+    async fn cancel_response(&self, _headers: Option<&HeaderMap>, response_id: &str) -> Response {
+        responses::cancel_response_impl(&self.response_storage, &self.background_tasks, response_id)
+            .await
     }

     async fn route_classify(
@@ -62,7 +62,10 @@ pub(super) async fn create_conversation(
         None => None,
     };

-    let new_conv = NewConversation { metadata };
+    let new_conv = NewConversation {
+        id: None, // Generate random ID (OpenAI behavior for POST /v1/conversations)
+        metadata,
+    };

     match conversation_storage.create_conversation(new_conv).await {
         Ok(conversation) => {
@@ -952,7 +955,7 @@ fn item_to_json(item: &crate::data_connector::conversation_items::ConversationIt
 // ============================================================================

 /// Persist conversation items (delegates to persist_items_with_storages)
-pub(super) async fn persist_conversation_items(
+pub async fn persist_conversation_items(
     conversation_storage: Arc<dyn ConversationStorage>,
     item_storage: Arc<dyn ConversationItemStorage>,
     response_storage: Arc<dyn ResponseStorage>,
@@ -129,7 +129,7 @@ impl FunctionCallInProgress {
 // ============================================================================

 /// Build a request-scoped MCP manager from request tools, if present.
-pub(super) async fn mcp_manager_from_request_tools(
+pub async fn mcp_manager_from_request_tools(
     tools: &[ResponseTool],
 ) -> Option<Arc<McpClientManager>> {
     let tool = tools
@@ -7,8 +7,8 @@
 //! - Multi-turn tool execution loops
 //! - SSE (Server-Sent Events) streaming

-mod conversations;
-mod mcp;
+pub mod conversations;
+pub mod mcp;
 mod responses;
 mod router;
 mod streaming;