[router][grpc] Support E2E non-stream chat completions (#10980)

Chang Su
2025-09-26 22:02:06 -07:00
committed by GitHub
parent bd95944cf6
commit 37f3325b06
8 changed files with 325 additions and 136 deletions


@@ -185,20 +185,8 @@ message GenerateComplete {
// Final output
repeated uint32 output_ids = 1;
// Finish reason
enum FinishReason {
// The model generated a stop sequence.
STOP = 0;
// The model reached the maximum generation length.
LENGTH = 1;
// The model generated an end-of-sequence (EOS) token.
EOS_TOKEN = 2;
// The model generated a user-provided stop string.
STOP_STR = 3;
// The request was aborted by the user or system.
ABORT = 4;
}
FinishReason finish_reason = 2;
// Finish reason as OpenAI-compatible string ("stop", "length", "abort")
string finish_reason = 2;
// Token usage counts
int32 prompt_tokens = 3;
@@ -210,6 +198,12 @@ message GenerateComplete {
// All hidden states if requested
repeated HiddenStates all_hidden_states = 7;
// Matched stop information (for stop sequences)
oneof matched_stop {
uint32 matched_token_id = 8;
string matched_stop_str = 9;
}
}
message GenerateError {
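
For illustration (not part of the commit): a minimal sketch of a GenerateComplete populated under the new schema, assuming the usual prost-generated Rust types for the message above (proto::GenerateComplete, proto::generate_complete::MatchedStop); exact paths and the unchanged fields depend on the build output.

// Illustrative only: the string finish_reason and the matched_stop oneof as a client would set or see them.
fn example_complete() -> proto::GenerateComplete {
    proto::GenerateComplete {
        output_ids: vec![9906, 1917, 13],      // placeholder token ids
        finish_reason: "stop".to_string(),     // was an enum; now an OpenAI-style string
        prompt_tokens: 12,
        completion_tokens: 3,
        matched_stop: Some(proto::generate_complete::MatchedStop::MatchedStopStr(
            "</s>".to_string(),
        )),
        ..Default::default()                   // remaining fields (cached tokens, hidden states, ...)
    }
}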


@@ -423,10 +423,25 @@ pub struct ChatCompletionResponse {
pub system_fingerprint: Option<String>,
}
/// Response message structure for ChatCompletionResponse (different from request ChatMessage)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionMessage {
pub role: String, // Always "assistant" for responses
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// Reasoning content for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
// Note: function_call is deprecated and not included
// Note: refusal, annotations, audio are not added yet
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
pub message: ChatCompletionMessage,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
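
A quick illustrative check (not from the commit) of how the new response message serializes; it assumes serde_json is available, as it already is elsewhere in the router crate.

#[test]
fn chat_completion_message_omits_empty_optionals() {
    // Illustrative only: fields marked skip_serializing_if are dropped from the JSON.
    let msg = ChatCompletionMessage {
        role: "assistant".to_string(),
        content: Some("Hello!".to_string()),
        tool_calls: None,
        reasoning_content: None,
    };
    assert_eq!(
        serde_json::to_string(&msg).unwrap(),
        r#"{"role":"assistant","content":"Hello!"}"#
    );
}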


@@ -8,6 +8,7 @@ use axum::{
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
Json,
};
use tracing::{debug, error, info, warn};
@@ -18,8 +19,9 @@ use crate::metrics::RouterMetrics;
use crate::policies::PolicyRegistry;
use crate::protocols::spec::ChatMessage;
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolChoice,
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams,
ResponsesRequest, StringOrArray, Tool, ToolChoice, Usage,
};
use crate::reasoning_parser::ParserFactory;
use crate::routers::RouterTrait;
@@ -30,6 +32,7 @@ use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::HuggingFaceTokenizer;
use crate::tool_parser::ParserRegistry;
use serde_json::Value;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio_stream::StreamExt;
use uuid::Uuid;
@@ -648,35 +651,98 @@ impl GrpcRouter {
Err(e) => return fail_fmt("Failed to start generation: ", &e),
};
// Get the single Complete response
let gen_response = match stream.next().await {
Some(Ok(r)) => r,
Some(Err(e)) => return fail_fmt("Failed to get GenerateResponse: ", &e),
None => return fail_str("No response from server"),
// Collect all responses (for n>1 support)
let mut all_responses = Vec::new();
while let Some(response) = stream.next().await {
match response {
Ok(gen_response) => match gen_response.response {
Some(proto::generate_response::Response::Complete(complete)) => {
all_responses.push(complete);
}
Some(proto::generate_response::Response::Error(err)) => {
error!("Generation failed for one choice: {}", err.message);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Generation failed: {}", err.message),
)
.into_response();
}
Some(proto::generate_response::Response::Chunk(_)) => {
return fail_str("Unexpected chunk response for non-streaming request")
}
None => return fail_str("Empty response from server"),
},
Err(e) => return fail_fmt("Failed to get GenerateResponse: ", &e),
}
}
if all_responses.is_empty() {
return fail_str("No responses from server");
}
// Process each response into a ChatChoice
let mut choices = Vec::new();
for (index, complete) in all_responses.iter().enumerate() {
match self
.process_single_choice(complete, index, original_request, &mut stop_decoder)
.await
{
Ok(choice) => choices.push(choice),
Err(e) => {
error!("Failed to process choice {}: {}", index, e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to process choice {}: {}", index, e),
)
.into_response();
}
}
}
// Aggregate usage information from all responses
let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
let total_completion_tokens: u32 = all_responses
.iter()
.map(|r| r.completion_tokens as u32)
.sum();
let usage = Usage {
prompt_tokens: total_prompt_tokens,
completion_tokens: total_completion_tokens,
total_tokens: total_prompt_tokens + total_completion_tokens,
completion_tokens_details: None,
};
// Extract the expected variant early
let complete = match gen_response.response {
Some(proto::generate_response::Response::Complete(c)) => c,
Some(proto::generate_response::Response::Error(err)) => {
error!("Generation failed: {}", err.message);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Generation failed: {}", err.message),
)
.into_response();
}
Some(proto::generate_response::Response::Chunk(_)) => {
return fail_str("Unexpected chunk response for non-streaming request")
}
None => return fail_str("Empty response from server"),
// Build final ChatCompletionResponse
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()),
object: "chat.completion".to_string(),
created: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: original_request.model.clone(),
choices,
usage: Some(usage),
system_fingerprint: None,
};
// Serialize and return JSON response
Json(response).into_response()
}
/// Process a single GenerateComplete response into a ChatChoice
async fn process_single_choice(
&self,
complete: &proto::GenerateComplete,
index: usize,
original_request: &ChatCompletionRequest,
stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder,
) -> Result<ChatChoice, String> {
stop_decoder.reset();
// Decode tokens
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
Ok(o) => o,
Err(e) => return fail_fmt("Failed to process tokens: ", &e),
};
let outputs = stop_decoder
.process_tokens(&complete.output_ids)
.map_err(|e| format!("Failed to process tokens: {}", e))?;
// Accumulate text with early breaks
let mut final_text = String::new();
@@ -697,8 +763,119 @@ impl GrpcRouter {
final_text.push_str(&t);
}
// TODO: Create proper OpenAI-compatible response
(StatusCode::OK, format!("Final text: {}", final_text)).into_response()
// Step 1: Handle reasoning content parsing
let mut reasoning_text: Option<String> = None;
let mut processed_text = final_text;
// Check if reasoning parsing is enabled and separate_reasoning is requested
if original_request.separate_reasoning {
if let Ok(mut parser) = self
.reasoning_parser_factory
.create(&original_request.model)
{
match parser.detect_and_parse_reasoning(&processed_text) {
Ok(result) => {
if !result.reasoning_text.is_empty() {
reasoning_text = Some(result.reasoning_text);
}
processed_text = result.normal_text;
}
Err(e) => {
return Err(format!("Reasoning parsing error: {}", e));
}
}
}
}
// Step 2: Handle tool call parsing
let mut tool_calls: Option<Vec<crate::protocols::spec::ToolCall>> = None;
// Check if tool calls should be processed
let tool_choice_enabled = !matches!(
&original_request.tool_choice,
Some(ToolChoice::Value(
crate::protocols::spec::ToolChoiceValue::None
))
);
if tool_choice_enabled && original_request.tools.is_some() {
if let Some(parser) = self
.tool_parser_registry
.get_parser(&original_request.model)
{
match parser.parse_complete(&processed_text).await {
Ok(parsed_tool_calls) => {
if !parsed_tool_calls.is_empty() {
let spec_tool_calls = parsed_tool_calls
.into_iter()
.map(|tc| crate::protocols::spec::ToolCall {
id: tc.id,
tool_type: "function".to_string(),
function: crate::protocols::spec::FunctionCallResponse {
name: tc.function.name,
arguments: Some(
serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
),
},
})
.collect();
tool_calls = Some(spec_tool_calls);
processed_text = String::new();
}
}
Err(e) => {
error!("Tool call parsing error: {}", e);
// Continue without tool calls rather than failing
}
}
}
}
// Step 3: Use finish reason directly from proto (already OpenAI-compatible string)
let finish_reason_str = &complete.finish_reason;
// Override finish reason if we have tool calls
let final_finish_reason_str = if tool_calls.is_some() {
"tool_calls"
} else {
finish_reason_str
};
// Extract matched_stop information from proto
let matched_stop = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => Some(
serde_json::Value::Number(serde_json::Number::from(*token_id)),
),
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(serde_json::Value::String(stop_str.clone()))
}
None => None,
};
// Step 4: Build ChatCompletionMessage (proper response message type)
let chat_message = ChatCompletionMessage {
role: "assistant".to_string(),
content: if processed_text.is_empty() {
None
} else {
Some(processed_text)
},
tool_calls,
reasoning_content: reasoning_text,
};
// Step 5: Build ChatChoice
let choice = ChatChoice {
index: index as u32,
message: chat_message,
logprobs: None,
finish_reason: Some(final_finish_reason_str.to_string()),
matched_stop,
hidden_states: None,
};
Ok(choice)
}
}
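
A worked example of the usage aggregation in the non-streaming handler above, with illustrative numbers: each Complete response carries its own prompt count, so the prompt is counted once per returned choice.

// Illustrative only: usage totals for n = 2 choices, mirroring the per-choice sums above.
fn example_usage_totals() {
    // (prompt_tokens, completion_tokens) reported by each Complete response
    let per_choice: [(u32, u32); 2] = [(12, 3), (12, 5)];
    let prompt_tokens: u32 = per_choice.iter().map(|&(p, _)| p).sum();
    let completion_tokens: u32 = per_choice.iter().map(|&(_, c)| c).sum();
    // total_tokens = prompt_tokens + completion_tokens
    assert_eq!(
        (prompt_tokens, completion_tokens, prompt_tokens + completion_tokens),
        (24, 8, 32)
    );
}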