[router][grpc] Support streaming for v1/chat/completions (#11179)
This commit is contained in:
@@ -4,7 +4,7 @@ use crate::config::types::RetryConfig;
|
||||
use crate::core::{WorkerRegistry, WorkerType};
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::reasoning_parser::ReasoningParserFactory;
|
||||
use crate::routers::RouterTrait;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ToolParserFactory;
|
||||
@@ -24,7 +24,7 @@ pub struct GrpcPDRouter {
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
policy_registry: Arc<PolicyRegistry>,
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
reasoning_parser_factory: ParserFactory,
|
||||
reasoning_parser_factory: ReasoningParserFactory,
|
||||
tool_parser_factory: ToolParserFactory,
|
||||
|
||||
dp_aware: bool,
|
||||
|
||||
@@ -7,10 +7,14 @@ use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use std::io;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::config::types::RetryConfig;
|
||||
@@ -21,11 +25,12 @@ use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::spec::ChatMessage;
|
||||
use crate::protocols::spec::{
|
||||
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
|
||||
CompletionRequest, EmbeddingRequest, FunctionCallResponse, GenerateRequest, RerankRequest,
|
||||
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolChoice,
|
||||
ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, CompletionRequest,
|
||||
EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, GenerateRequest, RerankRequest,
|
||||
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice,
|
||||
ToolChoiceValue, Usage,
|
||||
};
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::reasoning_parser::{ParserResult, ReasoningParserFactory};
|
||||
use crate::routers::RouterTrait;
|
||||
use crate::server::AppContext;
|
||||
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
||||
@@ -34,7 +39,7 @@ use crate::tokenizer::stop::{
|
||||
};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tokenizer::HuggingFaceTokenizer;
|
||||
use crate::tool_parser::ToolParserFactory;
|
||||
use crate::tool_parser::{StreamingParseResult, ToolParserFactory};
|
||||
use proto::generate_response::Response::{Chunk, Complete, Error};
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||
@@ -50,12 +55,13 @@ pub struct ProcessedMessages {
|
||||
}
|
||||
|
||||
/// gRPC router implementation for SGLang
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct GrpcRouter {
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
policy_registry: Arc<PolicyRegistry>,
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
reasoning_parser_factory: ParserFactory,
|
||||
reasoning_parser_factory: ReasoningParserFactory,
|
||||
tool_parser_factory: ToolParserFactory,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
@@ -776,10 +782,11 @@ impl GrpcRouter {
|
||||
}
|
||||
|
||||
/// Parse tool calls using model-specific parser
|
||||
async fn parse_with_model_parser(
|
||||
async fn parse_tool_calls(
|
||||
&self,
|
||||
processed_text: &str,
|
||||
model: &str,
|
||||
history_tool_calls_count: usize,
|
||||
) -> (Option<Vec<ToolCall>>, String) {
|
||||
// Get pooled parser for this model
|
||||
let pooled_parser = self.tool_parser_factory.get_pooled(model);
|
||||
@@ -810,16 +817,26 @@ impl GrpcRouter {
|
||||
|
||||
let spec_tool_calls = parsed_tool_calls
|
||||
.into_iter()
|
||||
.map(|tc| ToolCall {
|
||||
id: tc.id,
|
||||
tool_type: "function".to_string(),
|
||||
function: FunctionCallResponse {
|
||||
name: tc.function.name,
|
||||
arguments: Some(
|
||||
serde_json::to_string(&tc.function.arguments)
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
),
|
||||
},
|
||||
.enumerate()
|
||||
.map(|(index, tc)| {
|
||||
// Generate ID for this tool call
|
||||
let id = Self::generate_tool_call_id(
|
||||
model,
|
||||
&tc.function.name,
|
||||
index,
|
||||
history_tool_calls_count,
|
||||
);
|
||||
ToolCall {
|
||||
id,
|
||||
tool_type: "function".to_string(),
|
||||
function: FunctionCallResponse {
|
||||
name: tc.function.name,
|
||||
arguments: Some(
|
||||
serde_json::to_string(&tc.function.arguments)
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
),
|
||||
},
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
(Some(spec_tool_calls), normal_text)
|
||||
@@ -920,6 +937,47 @@ impl GrpcRouter {
|
||||
builder.build()
|
||||
}
|
||||
|
||||
/// Count the number of tool calls in the request message history
|
||||
/// This is used for KimiK2 format which needs globally unique indices
|
||||
fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize {
|
||||
request
|
||||
.messages
|
||||
.iter()
|
||||
.filter_map(|msg| {
|
||||
if let ChatMessage::Assistant { tool_calls, .. } = msg {
|
||||
tool_calls.as_ref().map(|calls| calls.len())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Generate a tool call ID based on model format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `model` - Model name to determine ID format
|
||||
/// * `tool_name` - Name of the tool being called
|
||||
/// * `tool_index` - Index of this tool call within the current message
|
||||
/// * `history_count` - Number of tool calls in previous messages
|
||||
///
|
||||
/// # Returns
|
||||
/// A unique ID string. KimiK2 uses `functions.{name}:{global_index}`, others use `call_{uuid}`
|
||||
fn generate_tool_call_id(
|
||||
model: &str,
|
||||
tool_name: &str,
|
||||
tool_index: usize,
|
||||
history_count: usize,
|
||||
) -> String {
|
||||
if model.to_lowercase().contains("kimi") {
|
||||
// KimiK2 format: functions.{name}:{global_index}
|
||||
format!("functions.{}:{}", tool_name, history_count + tool_index)
|
||||
} else {
|
||||
// Standard OpenAI format: call_{24-char-uuid}
|
||||
format!("call_{}", &Uuid::new_v4().simple().to_string()[..24])
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a chunk of tokens through the stop decoder
|
||||
fn process_chunk_tokens(
|
||||
stop_decoder: &mut StopSequenceDecoder,
|
||||
@@ -953,6 +1011,230 @@ impl GrpcRouter {
|
||||
(chunk_text, false) // Return text and continue processing
|
||||
}
|
||||
|
||||
/// Helper: Process reasoning content in streaming mode
|
||||
/// Returns (modified_delta, optional_reasoning_chunk)
|
||||
fn process_reasoning_stream(
|
||||
&self,
|
||||
delta: &str,
|
||||
index: u32,
|
||||
reasoning_parsers: &mut HashMap<
|
||||
u32,
|
||||
Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>,
|
||||
>,
|
||||
request_id: &str,
|
||||
model: &str,
|
||||
created: u64,
|
||||
) -> (String, Option<ChatCompletionStreamResponse>) {
|
||||
// Get or create parser for this index
|
||||
reasoning_parsers
|
||||
.entry(index)
|
||||
.or_insert_with(|| self.reasoning_parser_factory.get_pooled(model));
|
||||
|
||||
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
|
||||
let parse_result = {
|
||||
let mut parser = pooled_parser.lock().unwrap();
|
||||
parser.parse_reasoning_streaming_incremental(delta)
|
||||
};
|
||||
|
||||
match parse_result {
|
||||
Ok(ParserResult {
|
||||
reasoning_text,
|
||||
normal_text,
|
||||
}) => {
|
||||
let chunk = if !reasoning_text.is_empty() {
|
||||
Some(ChatCompletionStreamResponse {
|
||||
id: request_id.to_string(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.to_string(),
|
||||
system_fingerprint: None,
|
||||
choices: vec![ChatStreamChoice {
|
||||
index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
reasoning_content: Some(reasoning_text),
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
matched_stop: None,
|
||||
}],
|
||||
usage: None,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
return (normal_text, chunk);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Reasoning parsing error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(delta.to_string(), None)
|
||||
}
|
||||
|
||||
/// Helper: Process tool calls in streaming mode
|
||||
/// Returns (should_skip_content, chunks_to_emit)
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn process_tool_calls_stream(
|
||||
&self,
|
||||
delta: &str,
|
||||
index: u32,
|
||||
tool_parsers: &mut HashMap<
|
||||
u32,
|
||||
Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>,
|
||||
>,
|
||||
has_tool_calls: &mut HashMap<u32, bool>,
|
||||
tools: &[crate::protocols::spec::Tool],
|
||||
request_id: &str,
|
||||
model: &str,
|
||||
created: u64,
|
||||
history_tool_calls_count: usize,
|
||||
) -> (bool, Vec<ChatCompletionStreamResponse>) {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
// Get or create parser for this index
|
||||
tool_parsers
|
||||
.entry(index)
|
||||
.or_insert_with(|| self.tool_parser_factory.get_pooled(model));
|
||||
|
||||
if let Some(pooled_parser) = tool_parsers.get(&index) {
|
||||
let mut parser = pooled_parser.lock().await;
|
||||
match parser.parse_incremental(delta, tools).await {
|
||||
Ok(StreamingParseResult { normal_text, calls }) => {
|
||||
// Emit normal text if present
|
||||
if !normal_text.is_empty() {
|
||||
chunks.push(ChatCompletionStreamResponse {
|
||||
id: request_id.to_string(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.to_string(),
|
||||
system_fingerprint: None,
|
||||
choices: vec![ChatStreamChoice {
|
||||
index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: Some(normal_text),
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
matched_stop: None,
|
||||
}],
|
||||
usage: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Emit tool call chunks
|
||||
for tool_call_item in calls {
|
||||
has_tool_calls.insert(index, true);
|
||||
|
||||
let tool_call_id = if let Some(ref name) = tool_call_item.name {
|
||||
Some(Self::generate_tool_call_id(
|
||||
model,
|
||||
name,
|
||||
tool_call_item.tool_index,
|
||||
history_tool_calls_count,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let tool_call_delta = ToolCallDelta {
|
||||
index: tool_call_item.tool_index as u32,
|
||||
id: tool_call_id,
|
||||
tool_type: if tool_call_item.name.is_some() {
|
||||
Some("function".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
function: Some(FunctionCallDelta {
|
||||
name: tool_call_item.name,
|
||||
arguments: if !tool_call_item.parameters.is_empty() {
|
||||
Some(tool_call_item.parameters)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
chunks.push(ChatCompletionStreamResponse {
|
||||
id: request_id.to_string(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.to_string(),
|
||||
system_fingerprint: None,
|
||||
choices: vec![ChatStreamChoice {
|
||||
index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: None,
|
||||
tool_calls: Some(vec![tool_call_delta]),
|
||||
reasoning_content: None,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
matched_stop: None,
|
||||
}],
|
||||
usage: None,
|
||||
});
|
||||
}
|
||||
|
||||
// If we emitted chunks, skip regular content
|
||||
return (!chunks.is_empty(), chunks);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Tool call parsing error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(false, chunks)
|
||||
}
|
||||
|
||||
/// Helper: Create content chunk
|
||||
fn create_content_chunk(
|
||||
content: String,
|
||||
index: u32,
|
||||
request_id: &str,
|
||||
model: &str,
|
||||
created: u64,
|
||||
logprobs: Option<crate::protocols::spec::ChatLogProbs>,
|
||||
) -> ChatCompletionStreamResponse {
|
||||
ChatCompletionStreamResponse {
|
||||
id: request_id.to_string(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.to_string(),
|
||||
system_fingerprint: None,
|
||||
choices: vec![ChatStreamChoice {
|
||||
index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: Some(content),
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
logprobs,
|
||||
finish_reason: None,
|
||||
matched_stop: None,
|
||||
}],
|
||||
usage: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper: Format response as SSE chunk
|
||||
fn format_sse_chunk(response: &ChatCompletionStreamResponse) -> String {
|
||||
format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(response).unwrap_or_default()
|
||||
)
|
||||
}
|
||||
|
||||
/// Submit request and handle streaming response for chat completions route
|
||||
async fn handle_streaming_chat(
|
||||
&self,
|
||||
@@ -960,14 +1242,13 @@ impl GrpcRouter {
|
||||
request: proto::GenerateRequest,
|
||||
original_request: &ChatCompletionRequest,
|
||||
) -> Response {
|
||||
let mut stop_decoder = self.create_stop_decoder(
|
||||
original_request.stop.as_ref(),
|
||||
original_request.stop_token_ids.as_ref(),
|
||||
original_request.skip_special_tokens,
|
||||
original_request.no_stop_trim,
|
||||
);
|
||||
let request_id = request.request_id.clone();
|
||||
let model = original_request.model.clone();
|
||||
|
||||
// Process streaming tokens
|
||||
// Create channel for SSE streaming
|
||||
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
|
||||
|
||||
// Start the gRPC stream
|
||||
let mut grpc_stream = match client.generate(request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
@@ -980,49 +1261,414 @@ impl GrpcRouter {
|
||||
}
|
||||
};
|
||||
|
||||
let mut decoded_text = String::new();
|
||||
let stop_params = (
|
||||
original_request.stop.clone(),
|
||||
original_request.stop_token_ids.clone(),
|
||||
original_request.skip_special_tokens,
|
||||
original_request.no_stop_trim,
|
||||
);
|
||||
|
||||
// Spawn processing task
|
||||
let self_clone = self.clone();
|
||||
let original_request_clone = original_request.clone();
|
||||
tokio::spawn(async move {
|
||||
let result = Self::process_streaming_chunks(
|
||||
&self_clone,
|
||||
&mut grpc_stream,
|
||||
request_id,
|
||||
model,
|
||||
stop_params,
|
||||
original_request_clone,
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
let error_chunk = format!(
|
||||
"data: {}\n\n",
|
||||
json!({
|
||||
"error": {
|
||||
"message": e,
|
||||
"type": "internal_error"
|
||||
}
|
||||
})
|
||||
);
|
||||
let _ = tx.send(Ok(Bytes::from(error_chunk)));
|
||||
}
|
||||
|
||||
// Send DONE marker
|
||||
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
|
||||
});
|
||||
|
||||
// Create response with SSE headers
|
||||
let stream = UnboundedReceiverStream::new(rx);
|
||||
let mut response = Response::new(Body::from_stream(stream));
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Cache-Control", HeaderValue::from_static("no-cache"));
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Connection", HeaderValue::from_static("keep-alive"));
|
||||
response
|
||||
}
|
||||
|
||||
/// Process streaming chunks and send SSE events
|
||||
async fn process_streaming_chunks(
|
||||
router: &GrpcRouter,
|
||||
grpc_stream: &mut (impl tokio_stream::Stream<Item = Result<proto::GenerateResponse, tonic::Status>>
|
||||
+ Unpin),
|
||||
request_id: String,
|
||||
model: String,
|
||||
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
||||
original_request: ChatCompletionRequest,
|
||||
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
|
||||
) -> Result<(), String> {
|
||||
// Extract request parameters
|
||||
let separate_reasoning = original_request.separate_reasoning;
|
||||
let tool_choice = &original_request.tool_choice;
|
||||
let tools = &original_request.tools;
|
||||
let history_tool_calls_count = Self::get_history_tool_calls_count(&original_request);
|
||||
let stream_options = &original_request.stream_options;
|
||||
|
||||
// Phase 1: Initialize state tracking (per-index for n>1 support)
|
||||
let mut is_firsts: HashMap<u32, bool> = HashMap::new();
|
||||
let mut stream_buffers: HashMap<u32, String> = HashMap::new();
|
||||
let mut finish_reasons: HashMap<u32, String> = HashMap::new();
|
||||
let mut matched_stops: HashMap<u32, Option<Value>> = HashMap::new();
|
||||
let mut prompt_tokens: HashMap<u32, u32> = HashMap::new();
|
||||
let mut completion_tokens: HashMap<u32, u32> = HashMap::new();
|
||||
let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
|
||||
|
||||
// Parser state (lazy initialization per index)
|
||||
type PooledReasoningParser =
|
||||
Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>;
|
||||
let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
|
||||
|
||||
type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>;
|
||||
let mut tool_parsers: HashMap<u32, PooledToolParser> = HashMap::new();
|
||||
let mut has_tool_calls: HashMap<u32, bool> = HashMap::new();
|
||||
|
||||
// Create stop decoder
|
||||
let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params;
|
||||
let mut stop_decoder = router.create_stop_decoder(
|
||||
stop.as_ref(),
|
||||
stop_token_ids.as_ref(),
|
||||
skip_special_tokens,
|
||||
no_stop_trim,
|
||||
);
|
||||
|
||||
let created = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
// Phase 2: Main streaming loop
|
||||
while let Some(response) = grpc_stream.next().await {
|
||||
let gen_response = match response {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
error!("Stream error: {}", e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
|
||||
|
||||
match gen_response.response {
|
||||
Some(Chunk(chunk)) => {
|
||||
// Process tokens and check if we should stop
|
||||
let (chunk_text, should_stop) =
|
||||
let index = chunk.index;
|
||||
|
||||
// Process tokens through stop decoder
|
||||
let (chunk_text, _should_stop) =
|
||||
Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
|
||||
decoded_text.push_str(&chunk_text);
|
||||
if should_stop {
|
||||
break;
|
||||
|
||||
if chunk_text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Process logprobs if present
|
||||
let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs {
|
||||
match router.convert_proto_to_openai_logprobs(proto_logprobs) {
|
||||
Ok(logprobs) => Some(logprobs),
|
||||
Err(e) => {
|
||||
warn!("Failed to process logprobs: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Initialize stream buffer if first time
|
||||
let stream_buffer = stream_buffers.entry(index).or_default();
|
||||
|
||||
// Send first chunk with role
|
||||
if is_firsts.get(&index).copied().unwrap_or(true) {
|
||||
let first_chunk = ChatCompletionStreamResponse {
|
||||
id: request_id.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.clone(),
|
||||
system_fingerprint: None,
|
||||
choices: vec![ChatStreamChoice {
|
||||
index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
matched_stop: None,
|
||||
}],
|
||||
usage: None,
|
||||
};
|
||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk))))
|
||||
.map_err(|_| "Failed to send first chunk".to_string())?;
|
||||
is_firsts.insert(index, false);
|
||||
}
|
||||
|
||||
// Calculate delta
|
||||
let mut delta = chunk_text;
|
||||
stream_buffer.push_str(&delta);
|
||||
|
||||
// Reasoning content handling
|
||||
if separate_reasoning {
|
||||
let (normal_text, reasoning_chunk) = router.process_reasoning_stream(
|
||||
&delta,
|
||||
index,
|
||||
&mut reasoning_parsers,
|
||||
&request_id,
|
||||
&model,
|
||||
created,
|
||||
);
|
||||
if let Some(chunk) = reasoning_chunk {
|
||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
|
||||
.map_err(|_| "Failed to send reasoning chunk".to_string())?;
|
||||
}
|
||||
delta = normal_text;
|
||||
}
|
||||
|
||||
// Tool call handling
|
||||
let tool_choice_enabled =
|
||||
!matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None)));
|
||||
|
||||
if tool_choice_enabled && tools.is_some() {
|
||||
let (should_skip, tool_chunks) = router
|
||||
.process_tool_calls_stream(
|
||||
&delta,
|
||||
index,
|
||||
&mut tool_parsers,
|
||||
&mut has_tool_calls,
|
||||
tools.as_ref().unwrap(),
|
||||
&request_id,
|
||||
&model,
|
||||
created,
|
||||
history_tool_calls_count,
|
||||
)
|
||||
.await;
|
||||
|
||||
for chunk in tool_chunks {
|
||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
|
||||
.map_err(|_| "Failed to send tool call chunk".to_string())?;
|
||||
}
|
||||
|
||||
if should_skip {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Regular content emission
|
||||
if !delta.is_empty() {
|
||||
let content_chunk = Self::create_content_chunk(
|
||||
delta,
|
||||
index,
|
||||
&request_id,
|
||||
&model,
|
||||
created,
|
||||
choice_logprobs,
|
||||
);
|
||||
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk))))
|
||||
.map_err(|_| "Failed to send content chunk".to_string())?;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Some(Complete(_complete)) => {
|
||||
Some(Complete(complete)) => {
|
||||
// Flush any remaining text
|
||||
if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
|
||||
if !text.is_empty() {
|
||||
decoded_text.push_str(&text);
|
||||
debug!("Flushed text: {}", text);
|
||||
let index = complete.index;
|
||||
let stream_buffer = stream_buffers.entry(index).or_default();
|
||||
stream_buffer.push_str(&text);
|
||||
|
||||
let content_chunk = ChatCompletionStreamResponse {
|
||||
id: request_id.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.clone(),
|
||||
system_fingerprint: None,
|
||||
choices: vec![ChatStreamChoice {
|
||||
index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: Some(text),
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
matched_stop: None,
|
||||
}],
|
||||
usage: None,
|
||||
};
|
||||
|
||||
let sse_chunk = serde_json::to_string(&content_chunk)
|
||||
.map_err(|e| format!("Failed to serialize content chunk: {}", e))?;
|
||||
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
|
||||
.map_err(|_| "Failed to send flushed content".to_string())?;
|
||||
}
|
||||
}
|
||||
|
||||
// Store metadata
|
||||
let index = complete.index;
|
||||
prompt_tokens.insert(index, complete.prompt_tokens as u32);
|
||||
completion_tokens.insert(index, complete.completion_tokens as u32);
|
||||
cached_tokens.insert(index, complete.cached_tokens as u32);
|
||||
finish_reasons.insert(index, complete.finish_reason.clone());
|
||||
|
||||
// Extract matched_stop
|
||||
let matched_stop_value = match &complete.matched_stop {
|
||||
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
|
||||
Some(Value::Number(serde_json::Number::from(*token_id)))
|
||||
}
|
||||
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
|
||||
Some(Value::String(stop_str.clone()))
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
matched_stops.insert(index, matched_stop_value);
|
||||
|
||||
break;
|
||||
}
|
||||
Some(Error(error)) => {
|
||||
error!("Generation error: {}", error.message);
|
||||
break;
|
||||
return Err(error.message);
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Replace with proper SSE streaming response
|
||||
// For now, return the complete decoded text
|
||||
(StatusCode::OK, format!("Decoded text: {}", decoded_text)).into_response()
|
||||
// Phase 3: Check unstreamed tool args
|
||||
// Check if parsers have any remaining arguments that haven't been streamed yet
|
||||
for (index, parser) in &tool_parsers {
|
||||
let parser_guard = parser.lock().await;
|
||||
if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() {
|
||||
for tool_call_item in unstreamed_items {
|
||||
let tool_call_delta = ToolCallDelta {
|
||||
index: tool_call_item.tool_index as u32,
|
||||
id: None,
|
||||
tool_type: None, // No type for argument deltas
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None, // No name for argument deltas
|
||||
arguments: if !tool_call_item.parameters.is_empty() {
|
||||
Some(tool_call_item.parameters)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
let tool_chunk = ChatCompletionStreamResponse {
|
||||
id: request_id.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.clone(),
|
||||
system_fingerprint: None,
|
||||
choices: vec![ChatStreamChoice {
|
||||
index: *index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: None,
|
||||
tool_calls: Some(vec![tool_call_delta]),
|
||||
reasoning_content: None,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
matched_stop: None,
|
||||
}],
|
||||
usage: None,
|
||||
};
|
||||
|
||||
let sse_chunk = serde_json::to_string(&tool_chunk)
|
||||
.map_err(|e| format!("Failed to serialize tool chunk: {}", e))?;
|
||||
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
|
||||
.map_err(|_| "Failed to send unstreamed tool args".to_string())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 4: Finish reason chunks
|
||||
for (index, finish_reason) in finish_reasons.iter() {
|
||||
let final_finish_reason =
|
||||
if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" {
|
||||
"tool_calls".to_string()
|
||||
} else {
|
||||
finish_reason.clone()
|
||||
};
|
||||
|
||||
let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone());
|
||||
|
||||
let finish_chunk = ChatCompletionStreamResponse {
|
||||
id: request_id.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.clone(),
|
||||
system_fingerprint: None,
|
||||
choices: vec![ChatStreamChoice {
|
||||
index: *index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: Some(final_finish_reason),
|
||||
matched_stop: matched_stop_value,
|
||||
}],
|
||||
usage: None,
|
||||
};
|
||||
|
||||
let sse_chunk = serde_json::to_string(&finish_chunk)
|
||||
.map_err(|e| format!("Failed to serialize finish chunk: {}", e))?;
|
||||
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
|
||||
.map_err(|_| "Failed to send finish chunk".to_string())?;
|
||||
}
|
||||
|
||||
// Phase 5: Usage chunk
|
||||
if let Some(stream_opts) = stream_options {
|
||||
if stream_opts.include_usage.unwrap_or(false) {
|
||||
let total_prompt: u32 = prompt_tokens.values().sum();
|
||||
let total_completion: u32 = completion_tokens.values().sum();
|
||||
|
||||
let usage_chunk = ChatCompletionStreamResponse {
|
||||
id: request_id.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.clone(),
|
||||
system_fingerprint: None,
|
||||
choices: vec![],
|
||||
usage: Some(Usage {
|
||||
prompt_tokens: total_prompt,
|
||||
completion_tokens: total_completion,
|
||||
total_tokens: total_prompt + total_completion,
|
||||
completion_tokens_details: None,
|
||||
}),
|
||||
};
|
||||
|
||||
let sse_chunk = serde_json::to_string(&usage_chunk)
|
||||
.map_err(|e| format!("Failed to serialize usage chunk: {}", e))?;
|
||||
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
|
||||
.map_err(|_| "Failed to send usage chunk".to_string())?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Submit request and handle non-streaming response for chat completions route
|
||||
@@ -1082,10 +1728,17 @@ impl GrpcRouter {
|
||||
}
|
||||
|
||||
// Process each response into a ChatChoice
|
||||
let history_tool_calls_count = Self::get_history_tool_calls_count(original_request);
|
||||
let mut choices = Vec::new();
|
||||
for (index, complete) in all_responses.iter().enumerate() {
|
||||
match self
|
||||
.process_single_choice(complete, index, original_request, &mut stop_decoder)
|
||||
.process_single_choice(
|
||||
complete,
|
||||
index,
|
||||
original_request,
|
||||
&mut stop_decoder,
|
||||
history_tool_calls_count,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(choice) => choices.push(choice),
|
||||
@@ -1216,11 +1869,12 @@ impl GrpcRouter {
|
||||
decoded_text.push_str(&t);
|
||||
}
|
||||
|
||||
let output_ids = complete.output_ids.clone();
|
||||
let output_ids = std::mem::take(&mut complete.output_ids);
|
||||
let finish_reason = std::mem::take(&mut complete.finish_reason);
|
||||
|
||||
// Build base meta_info using json! macro
|
||||
let mut meta_info = json!({
|
||||
"finish_reason": complete.finish_reason.clone(),
|
||||
"finish_reason": finish_reason,
|
||||
"prompt_tokens": complete.prompt_tokens,
|
||||
"completion_tokens": complete.completion_tokens,
|
||||
"cached_tokens": complete.cached_tokens,
|
||||
@@ -1269,9 +1923,13 @@ impl GrpcRouter {
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build ChatLogProbsContent for each token
|
||||
for (i, &logprob) in proto_logprobs.token_logprobs.iter().enumerate() {
|
||||
let token_text = token_texts.get(i).cloned().unwrap_or_default();
|
||||
// Build ChatLogProbsContent for each token (consume iterator to avoid clones)
|
||||
for (i, (&logprob, token_text)) in proto_logprobs
|
||||
.token_logprobs
|
||||
.iter()
|
||||
.zip(token_texts.into_iter())
|
||||
.enumerate()
|
||||
{
|
||||
let bytes = Some(token_text.as_bytes().to_vec());
|
||||
|
||||
// Build top_logprobs for this position
|
||||
@@ -1324,6 +1982,7 @@ impl GrpcRouter {
|
||||
index: usize,
|
||||
original_request: &ChatCompletionRequest,
|
||||
stop_decoder: &mut StopSequenceDecoder,
|
||||
history_tool_calls_count: usize,
|
||||
) -> Result<ChatChoice, String> {
|
||||
stop_decoder.reset();
|
||||
// Decode tokens
|
||||
@@ -1401,7 +2060,11 @@ impl GrpcRouter {
|
||||
self.parse_json_schema_response(&processed_text, &original_request.tool_choice);
|
||||
} else {
|
||||
(tool_calls, processed_text) = self
|
||||
.parse_with_model_parser(&processed_text, &original_request.model)
|
||||
.parse_tool_calls(
|
||||
&processed_text,
|
||||
&original_request.model,
|
||||
history_tool_calls_count,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
@@ -1686,7 +2349,6 @@ mod tests {
|
||||
content: Some("Assistant response".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
function_call: None,
|
||||
reasoning_content: None,
|
||||
}];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user