[router][grpc] Support E2E non-stream chat completions (#10980)

Author: Chang Su
Date: 2025-09-26 22:02:06 -07:00
Committed by: GitHub
parent bd95944cf6
commit 37f3325b06
8 changed files with 325 additions and 136 deletions

View File

@@ -13,7 +13,7 @@ import sys
 import threading
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union

 import grpc
 import zmq
@@ -156,7 +156,7 @@ class GrpcRequestManager:
         obj: TokenizedGenerateReqInput,
         request_id: Optional[str] = None,
         grpc_context: Optional[grpc.aio.ServicerContext] = None,
-    ):
+    ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
         """
         Submit a generation request to the scheduler with n>1 parallel sampling support.
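
With generate_request now typed as an async generator, a caller iterates over it with async for instead of awaiting a queue. A minimal sketch of the consuming side; the manager and tokenized_req names are illustrative stand-ins, not part of this diff:

import asyncio

async def collect_outputs(manager, tokenized_req):
    # generate_request yields a Dict per output, or a List[Dict] when
    # n > 1 parallel samples are requested; iterate until exhausted.
    outputs = []
    async for out in manager.generate_request(tokenized_req, request_id="demo-1"):
        outputs.append(out)
    return outputs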

View File

@@ -321,14 +321,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         logger.info(f"Sending health check request to request manager...")

         # Submit and wait for response
-        output_queue = await self.request_manager.generate_request(
+        output_generator = self.request_manager.generate_request(
             health_request, request_id=rid
         )

         try:
-            # Wait for response with configurable timeout
+            # Get first response with timeout
             response = await asyncio.wait_for(
-                output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
+                output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
             )

             # Clean up
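
The health check now pulls a single item off the generator rather than a queue; wrapping __anext__() in asyncio.wait_for is the standard way to bound that wait. A self-contained sketch of the same pattern, with an illustrative generator and timeout:

import asyncio

async def first_with_timeout(agen, timeout):
    # Await the generator's first item; asyncio.TimeoutError propagates
    # if nothing arrives within `timeout` seconds.
    try:
        return await asyncio.wait_for(agen.__anext__(), timeout=timeout)
    except StopAsyncIteration:
        return None  # generator ended without yielding anything

async def demo():
    async def healthy():
        await asyncio.sleep(0.1)
        yield {"status": "ok"}

    print(await first_with_timeout(healthy(), timeout=1.0))

asyncio.run(demo())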
@@ -492,13 +492,32 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
     ) -> sglang_scheduler_pb2.GenerateResponse:
         """Create a completion response."""

-        # Determine finish reason
-        finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
+        # Extract meta info and finish reason details
         meta_info = output.get("meta_info", {})
-        if meta_info.get("finish_reason") == "length":
-            finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
-        elif meta_info.get("finish_reason") == "eos_token":
-            finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
+        finish_reason_data = meta_info.get("finish_reason")
+
+        # Determine finish reason, default is stop
+        finish_reason = "stop"
+        if finish_reason_data:
+            if isinstance(finish_reason_data, dict):
+                finish_reason_type = finish_reason_data.get("type")
+            else:
+                # Handle legacy string format
+                finish_reason_type = finish_reason_data
+
+            if finish_reason_type == "length":
+                finish_reason = "length"
+            elif finish_reason_type == "abort":
+                finish_reason = "abort"
+
+        # Extract matched_stop information
+        matched_stop_kwargs = {}
+        if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
+            matched = finish_reason_data["matched"]
+            if isinstance(matched, int):
+                matched_stop_kwargs["matched_token_id"] = matched
+            elif isinstance(matched, str):
+                matched_stop_kwargs["matched_stop_str"] = matched

         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
@@ -510,6 +529,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                 "completion_tokens", len(output.get("token_ids", []))
             ),
             cached_tokens=meta_info.get("cached_tokens", 0),
+            **matched_stop_kwargs,
         ),
     )
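
Pulled out of the handler, the new mapping logic is easy to exercise on its own: it accepts either the dict-shaped meta_info finish reason ({"type": ..., "matched": ...}) or the legacy plain string, and returns the OpenAI-style reason plus the oneof kwargs. A standalone sketch (the helper name is illustrative):

def map_finish_reason(finish_reason_data):
    finish_reason = "stop"  # OpenAI-compatible default
    matched_stop_kwargs = {}
    if finish_reason_data:
        if isinstance(finish_reason_data, dict):
            finish_reason_type = finish_reason_data.get("type")
        else:
            finish_reason_type = finish_reason_data  # legacy string format
        if finish_reason_type == "length":
            finish_reason = "length"
        elif finish_reason_type == "abort":
            finish_reason = "abort"
    if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
        matched = finish_reason_data["matched"]
        if isinstance(matched, int):
            matched_stop_kwargs["matched_token_id"] = matched
        elif isinstance(matched, str):
            matched_stop_kwargs["matched_stop_str"] = matched
    return finish_reason, matched_stop_kwargs

assert map_finish_reason({"type": "stop", "matched": 2}) == ("stop", {"matched_token_id": 2})
assert map_finish_reason("length") == ("length", {})
assert map_finish_reason(None) == ("stop", {})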

View File

@@ -185,20 +185,8 @@ message GenerateComplete {
   // Final output
   repeated uint32 output_ids = 1;

-  // Finish reason
-  enum FinishReason {
-    // The model generated a stop sequence.
-    STOP = 0;
-    // The model reached the maximum generation length.
-    LENGTH = 1;
-    // The model generated an end-of-sequence (EOS) token.
-    EOS_TOKEN = 2;
-    // The model generated a user-provided stop string.
-    STOP_STR = 3;
-    // The request was aborted by the user or system.
-    ABORT = 4;
-  }
-  FinishReason finish_reason = 2;
+  // Finish reason as OpenAI-compatible string ("stop", "length", "abort")
+  string finish_reason = 2;

   // Token usage counts
   int32 prompt_tokens = 3;
@@ -210,6 +198,12 @@ message GenerateComplete {
   // All hidden states if requested
   repeated HiddenStates all_hidden_states = 7;
+
+  // Matched stop information (for stop sequences)
+  oneof matched_stop {
+    uint32 matched_token_id = 8;
+    string matched_stop_str = 9;
+  }
 }

 message GenerateError {

File diff suppressed because one or more lines are too long
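
The suppressed file is presumably the regenerated sglang_scheduler_pb2.py, whose machine-generated lines exceed the viewer's length limit. From that generated module, the new schema is used as below; this is a sketch, and the import path is an assumption rather than something shown in the diff. Note the protobuf oneof semantics: setting one member of matched_stop clears the other.

from sglang.srt.grpc import sglang_scheduler_pb2 as pb2  # import path assumed

complete = pb2.GenerateComplete(
    output_ids=[15, 27, 2],
    finish_reason="stop",   # now a plain OpenAI-style string, not an enum
    matched_token_id=2,     # one member of the matched_stop oneof
)
complete.matched_stop_str = "</s>"  # setting the other member clears the first
assert complete.WhichOneof("matched_stop") == "matched_stop_str"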

View File

@@ -3,7 +3,6 @@ import datetime
 from google.protobuf import timestamp_pb2 as _timestamp_pb2
 from google.protobuf import struct_pb2 as _struct_pb2
 from google.protobuf.internal import containers as _containers
-from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
 from collections.abc import Iterable as _Iterable, Mapping as _Mapping
@@ -179,19 +178,7 @@ class GenerateStreamChunk(_message.Message):
     def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...

 class GenerateComplete(_message.Message):
-    __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states")
-    class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
-        __slots__ = ()
-        STOP: _ClassVar[GenerateComplete.FinishReason]
-        LENGTH: _ClassVar[GenerateComplete.FinishReason]
-        EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
-        STOP_STR: _ClassVar[GenerateComplete.FinishReason]
-        ABORT: _ClassVar[GenerateComplete.FinishReason]
-    STOP: GenerateComplete.FinishReason
-    LENGTH: GenerateComplete.FinishReason
-    EOS_TOKEN: GenerateComplete.FinishReason
-    STOP_STR: GenerateComplete.FinishReason
-    ABORT: GenerateComplete.FinishReason
+    __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str")
     OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
     FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
@@ -199,14 +186,18 @@ class GenerateComplete(_message.Message):
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
     ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
+    MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+    MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
     output_ids: _containers.RepeatedScalarFieldContainer[int]
-    finish_reason: GenerateComplete.FinishReason
+    finish_reason: str
     prompt_tokens: int
     completion_tokens: int
     cached_tokens: int
     all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
     all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
-    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
+    matched_token_id: int
+    matched_stop_str: str
+    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ...) -> None: ...

 class GenerateError(_message.Message):
     __slots__ = ("message", "http_status_code", "details")