[router][grpc] Support E2E non-stream chat completions (#10980)

This commit is contained in:
Chang Su
2025-09-26 22:02:06 -07:00
committed by GitHub
parent bd95944cf6
commit 37f3325b06
8 changed files with 325 additions and 136 deletions

View File

@@ -13,7 +13,7 @@ import sys
import threading import threading
import time import time
import uuid import uuid
from typing import Any, Dict, List, Optional, Union from typing import Any, AsyncGenerator, Dict, List, Optional, Union
import grpc import grpc
import zmq import zmq
@@ -156,7 +156,7 @@ class GrpcRequestManager:
obj: TokenizedGenerateReqInput, obj: TokenizedGenerateReqInput,
request_id: Optional[str] = None, request_id: Optional[str] = None,
grpc_context: Optional[grpc.aio.ServicerContext] = None, grpc_context: Optional[grpc.aio.ServicerContext] = None,
): ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
""" """
Submit a generation request to the scheduler with n>1 parallel sampling support. Submit a generation request to the scheduler with n>1 parallel sampling support.

View File

@@ -321,14 +321,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
logger.info(f"Sending health check request to request manager...") logger.info(f"Sending health check request to request manager...")
# Submit and wait for response # Submit and wait for response
output_queue = await self.request_manager.generate_request( output_generator = self.request_manager.generate_request(
health_request, request_id=rid health_request, request_id=rid
) )
try: try:
# Wait for response with configurable timeout # Get first response with timeout
response = await asyncio.wait_for( response = await asyncio.wait_for(
output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
) )
# Clean up # Clean up
@@ -492,13 +492,32 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
) -> sglang_scheduler_pb2.GenerateResponse: ) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a completion response.""" """Create a completion response."""
# Determine finish reason # Extract meta info and finish reason details
finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
meta_info = output.get("meta_info", {}) meta_info = output.get("meta_info", {})
if meta_info.get("finish_reason") == "length": finish_reason_data = meta_info.get("finish_reason")
finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
elif meta_info.get("finish_reason") == "eos_token": # Determine finish reason, default is stop
finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN finish_reason = "stop"
if finish_reason_data:
if isinstance(finish_reason_data, dict):
finish_reason_type = finish_reason_data.get("type")
else:
# Handle legacy string format
finish_reason_type = finish_reason_data
if finish_reason_type == "length":
finish_reason = "length"
elif finish_reason_type == "abort":
finish_reason = "abort"
# Extract matched_stop information
matched_stop_kwargs = {}
if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
matched = finish_reason_data["matched"]
if isinstance(matched, int):
matched_stop_kwargs["matched_token_id"] = matched
elif isinstance(matched, str):
matched_stop_kwargs["matched_stop_str"] = matched
return sglang_scheduler_pb2.GenerateResponse( return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id, request_id=request_id,
@@ -510,6 +529,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
"completion_tokens", len(output.get("token_ids", [])) "completion_tokens", len(output.get("token_ids", []))
), ),
cached_tokens=meta_info.get("cached_tokens", 0), cached_tokens=meta_info.get("cached_tokens", 0),
**matched_stop_kwargs,
), ),
) )

View File

@@ -185,20 +185,8 @@ message GenerateComplete {
// Final output // Final output
repeated uint32 output_ids = 1; repeated uint32 output_ids = 1;
// Finish reason // Finish reason as OpenAI-compatible string ("stop", "length", "abort")
enum FinishReason { string finish_reason = 2;
// The model generated a stop sequence.
STOP = 0;
// The model reached the maximum generation length.
LENGTH = 1;
// The model generated an end-of-sequence (EOS) token.
EOS_TOKEN = 2;
// The model generated a user-provided stop string.
STOP_STR = 3;
// The request was aborted by the user or system.
ABORT = 4;
}
FinishReason finish_reason = 2;
// Token usage counts // Token usage counts
int32 prompt_tokens = 3; int32 prompt_tokens = 3;
@@ -210,6 +198,12 @@ message GenerateComplete {
// All hidden states if requested // All hidden states if requested
repeated HiddenStates all_hidden_states = 7; repeated HiddenStates all_hidden_states = 7;
// Matched stop information (for stop sequences)
oneof matched_stop {
uint32 matched_token_id = 8;
string matched_stop_str = 9;
}
} }
message GenerateError { message GenerateError {

File diff suppressed because one or more lines are too long

View File

@@ -3,7 +3,6 @@ import datetime
from google.protobuf import timestamp_pb2 as _timestamp_pb2 from google.protobuf import timestamp_pb2 as _timestamp_pb2
from google.protobuf import struct_pb2 as _struct_pb2 from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message from google.protobuf import message as _message
from collections.abc import Iterable as _Iterable, Mapping as _Mapping from collections.abc import Iterable as _Iterable, Mapping as _Mapping
@@ -179,19 +178,7 @@ class GenerateStreamChunk(_message.Message):
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ... def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
class GenerateComplete(_message.Message): class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states") __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str")
class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
STOP: _ClassVar[GenerateComplete.FinishReason]
LENGTH: _ClassVar[GenerateComplete.FinishReason]
EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
STOP_STR: _ClassVar[GenerateComplete.FinishReason]
ABORT: _ClassVar[GenerateComplete.FinishReason]
STOP: GenerateComplete.FinishReason
LENGTH: GenerateComplete.FinishReason
EOS_TOKEN: GenerateComplete.FinishReason
STOP_STR: GenerateComplete.FinishReason
ABORT: GenerateComplete.FinishReason
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int] OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int] FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
@@ -199,14 +186,18 @@ class GenerateComplete(_message.Message):
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int] ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int] output_ids: _containers.RepeatedScalarFieldContainer[int]
finish_reason: GenerateComplete.FinishReason finish_reason: str
prompt_tokens: int prompt_tokens: int
completion_tokens: int completion_tokens: int
cached_tokens: int cached_tokens: int
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs] all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates] all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ... matched_token_id: int
matched_stop_str: str
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ...) -> None: ...
class GenerateError(_message.Message): class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details") __slots__ = ("message", "http_status_code", "details")

View File

@@ -185,20 +185,8 @@ message GenerateComplete {
// Final output // Final output
repeated uint32 output_ids = 1; repeated uint32 output_ids = 1;
// Finish reason // Finish reason as OpenAI-compatible string ("stop", "length", "abort")
enum FinishReason { string finish_reason = 2;
// The model generated a stop sequence.
STOP = 0;
// The model reached the maximum generation length.
LENGTH = 1;
// The model generated an end-of-sequence (EOS) token.
EOS_TOKEN = 2;
// The model generated a user-provided stop string.
STOP_STR = 3;
// The request was aborted by the user or system.
ABORT = 4;
}
FinishReason finish_reason = 2;
// Token usage counts // Token usage counts
int32 prompt_tokens = 3; int32 prompt_tokens = 3;
@@ -210,6 +198,12 @@ message GenerateComplete {
// All hidden states if requested // All hidden states if requested
repeated HiddenStates all_hidden_states = 7; repeated HiddenStates all_hidden_states = 7;
// Matched stop information (for stop sequences)
oneof matched_stop {
uint32 matched_token_id = 8;
string matched_stop_str = 9;
}
} }
message GenerateError { message GenerateError {

View File

@@ -423,10 +423,25 @@ pub struct ChatCompletionResponse {
pub system_fingerprint: Option<String>, pub system_fingerprint: Option<String>,
} }
/// Response message structure for ChatCompletionResponse (different from request ChatMessage)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionMessage {
pub role: String, // Always "assistant" for responses
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// Reasoning content for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
// Note: function_call is deprecated and not included
// Note: refusal, annotations, audio are not added yet
}
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice { pub struct ChatChoice {
pub index: u32, pub index: u32,
pub message: ChatMessage, pub message: ChatCompletionMessage,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>, pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call" pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"

View File

@@ -8,6 +8,7 @@ use axum::{
extract::Request, extract::Request,
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json,
}; };
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
@@ -18,8 +19,9 @@ use crate::metrics::RouterMetrics;
use crate::policies::PolicyRegistry; use crate::policies::PolicyRegistry;
use crate::protocols::spec::ChatMessage; use crate::protocols::spec::ChatMessage;
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolChoice, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams,
ResponsesRequest, StringOrArray, Tool, ToolChoice, Usage,
}; };
use crate::reasoning_parser::ParserFactory; use crate::reasoning_parser::ParserFactory;
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
@@ -30,6 +32,7 @@ use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::HuggingFaceTokenizer; use crate::tokenizer::HuggingFaceTokenizer;
use crate::tool_parser::ParserRegistry; use crate::tool_parser::ParserRegistry;
use serde_json::Value; use serde_json::Value;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use uuid::Uuid; use uuid::Uuid;
@@ -648,35 +651,98 @@ impl GrpcRouter {
Err(e) => return fail_fmt("Failed to start generation: ", &e), Err(e) => return fail_fmt("Failed to start generation: ", &e),
}; };
// Get the single Complete response // Collect all responses (for n>1 support)
let gen_response = match stream.next().await { let mut all_responses = Vec::new();
Some(Ok(r)) => r, while let Some(response) = stream.next().await {
Some(Err(e)) => return fail_fmt("Failed to get GenerateResponse: ", &e), match response {
None => return fail_str("No response from server"), Ok(gen_response) => match gen_response.response {
Some(proto::generate_response::Response::Complete(complete)) => {
all_responses.push(complete);
}
Some(proto::generate_response::Response::Error(err)) => {
error!("Generation failed for one choice: {}", err.message);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Generation failed: {}", err.message),
)
.into_response();
}
Some(proto::generate_response::Response::Chunk(_)) => {
return fail_str("Unexpected chunk response for non-streaming request")
}
None => return fail_str("Empty response from server"),
},
Err(e) => return fail_fmt("Failed to get GenerateResponse: ", &e),
}
}
if all_responses.is_empty() {
return fail_str("No responses from server");
}
// Process each response into a ChatChoice
let mut choices = Vec::new();
for (index, complete) in all_responses.iter().enumerate() {
match self
.process_single_choice(complete, index, original_request, &mut stop_decoder)
.await
{
Ok(choice) => choices.push(choice),
Err(e) => {
error!("Failed to process choice {}: {}", index, e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to process choice {}: {}", index, e),
)
.into_response();
}
}
}
// Aggregate usage information from all responses
let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
let total_completion_tokens: u32 = all_responses
.iter()
.map(|r| r.completion_tokens as u32)
.sum();
let usage = Usage {
prompt_tokens: total_prompt_tokens,
completion_tokens: total_completion_tokens,
total_tokens: total_prompt_tokens + total_completion_tokens,
completion_tokens_details: None,
}; };
// Extract the expected variant early // Build final ChatCompletionResponse
let complete = match gen_response.response { let response = ChatCompletionResponse {
Some(proto::generate_response::Response::Complete(c)) => c, id: format!("chatcmpl-{}", Uuid::new_v4()),
Some(proto::generate_response::Response::Error(err)) => { object: "chat.completion".to_string(),
error!("Generation failed: {}", err.message); created: SystemTime::now()
return ( .duration_since(UNIX_EPOCH)
StatusCode::INTERNAL_SERVER_ERROR, .unwrap_or_default()
format!("Generation failed: {}", err.message), .as_secs(),
) model: original_request.model.clone(),
.into_response(); choices,
} usage: Some(usage),
Some(proto::generate_response::Response::Chunk(_)) => { system_fingerprint: None,
return fail_str("Unexpected chunk response for non-streaming request")
}
None => return fail_str("Empty response from server"),
}; };
// Serialize and return JSON response
Json(response).into_response()
}
/// Process a single GenerateComplete response into a ChatChoice
async fn process_single_choice(
&self,
complete: &proto::GenerateComplete,
index: usize,
original_request: &ChatCompletionRequest,
stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder,
) -> Result<ChatChoice, String> {
stop_decoder.reset();
// Decode tokens // Decode tokens
let outputs = match stop_decoder.process_tokens(&complete.output_ids) { let outputs = stop_decoder
Ok(o) => o, .process_tokens(&complete.output_ids)
Err(e) => return fail_fmt("Failed to process tokens: ", &e), .map_err(|e| format!("Failed to process tokens: {}", e))?;
};
// Accumulate text with early breaks // Accumulate text with early breaks
let mut final_text = String::new(); let mut final_text = String::new();
@@ -697,8 +763,119 @@ impl GrpcRouter {
final_text.push_str(&t); final_text.push_str(&t);
} }
// TODO: Create proper OpenAI-compatible response // Step 1: Handle reasoning content parsing
(StatusCode::OK, format!("Final text: {}", final_text)).into_response() let mut reasoning_text: Option<String> = None;
let mut processed_text = final_text;
// Check if reasoning parsing is enabled and separate_reasoning is requested
if original_request.separate_reasoning {
if let Ok(mut parser) = self
.reasoning_parser_factory
.create(&original_request.model)
{
match parser.detect_and_parse_reasoning(&processed_text) {
Ok(result) => {
if !result.reasoning_text.is_empty() {
reasoning_text = Some(result.reasoning_text);
}
processed_text = result.normal_text;
}
Err(e) => {
return Err(format!("Reasoning parsing error: {}", e));
}
}
}
}
// Step 2: Handle tool call parsing
let mut tool_calls: Option<Vec<crate::protocols::spec::ToolCall>> = None;
// Check if tool calls should be processed
let tool_choice_enabled = !matches!(
&original_request.tool_choice,
Some(ToolChoice::Value(
crate::protocols::spec::ToolChoiceValue::None
))
);
if tool_choice_enabled && original_request.tools.is_some() {
if let Some(parser) = self
.tool_parser_registry
.get_parser(&original_request.model)
{
match parser.parse_complete(&processed_text).await {
Ok(parsed_tool_calls) => {
if !parsed_tool_calls.is_empty() {
let spec_tool_calls = parsed_tool_calls
.into_iter()
.map(|tc| crate::protocols::spec::ToolCall {
id: tc.id,
tool_type: "function".to_string(),
function: crate::protocols::spec::FunctionCallResponse {
name: tc.function.name,
arguments: Some(
serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
),
},
})
.collect();
tool_calls = Some(spec_tool_calls);
processed_text = String::new();
}
}
Err(e) => {
error!("Tool call parsing error: {}", e);
// Continue without tool calls rather than failing
}
}
}
}
// Step 3: Use finish reason directly from proto (already OpenAI-compatible string)
let finish_reason_str = &complete.finish_reason;
// Override finish reason if we have tool calls
let final_finish_reason_str = if tool_calls.is_some() {
"tool_calls"
} else {
finish_reason_str
};
// Extract matched_stop information from proto
let matched_stop = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => Some(
serde_json::Value::Number(serde_json::Number::from(*token_id)),
),
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(serde_json::Value::String(stop_str.clone()))
}
None => None,
};
// Step 4: Build ChatCompletionMessage (proper response message type)
let chat_message = ChatCompletionMessage {
role: "assistant".to_string(),
content: if processed_text.is_empty() {
None
} else {
Some(processed_text)
},
tool_calls,
reasoning_content: reasoning_text,
};
// Step 5: Build ChatChoice
let choice = ChatChoice {
index: index as u32,
message: chat_message,
logprobs: None,
finish_reason: Some(final_finish_reason_str.to_string()),
matched_stop,
hidden_states: None,
};
Ok(choice)
} }
} }