[router][grpc] Support streaming for v1/chat/completions (#11179)

This commit is contained in:
Chang Su
2025-10-02 14:35:16 -07:00
committed by GitHub
parent 0618ad6dd5
commit 963175d5c0
30 changed files with 912 additions and 228 deletions

View File

@@ -578,7 +578,7 @@ class GrpcRequestManager:
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
),
"finish_reason": (
str(batch_out.finished_reasons[i])
batch_out.finished_reasons[i]
if batch_out.finished_reasons[i]
else None
),

View File

@@ -112,7 +112,6 @@ def _launch_scheduler_process_only(
pp_rank,
None,
writer,
None,
),
)
@@ -583,6 +582,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
index=output.get("index", 0),
),
)
@@ -640,6 +640,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
index=output.get("index", 0),
**matched_stop_kwargs,
),
)

View File

@@ -179,6 +179,9 @@ message GenerateStreamChunk {
// Input logprobs (if requested) - only in first chunk
InputLogProbs input_logprobs = 7;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 8;
}
message GenerateComplete {
@@ -207,6 +210,9 @@ message GenerateComplete {
// Input logprobs if requested (for prompt tokens)
InputLogProbs input_logprobs = 10;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 11;
}
message GenerateError {

File diff suppressed because one or more lines are too long

View File

@@ -160,7 +160,7 @@ class GenerateResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message):
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index")
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
@@ -168,6 +168,7 @@ class GenerateStreamChunk(_message.Message):
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
INDEX_FIELD_NUMBER: _ClassVar[int]
token_ids: _containers.RepeatedScalarFieldContainer[int]
prompt_tokens: int
completion_tokens: int
@@ -175,10 +176,11 @@ class GenerateStreamChunk(_message.Message):
output_logprobs: OutputLogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
input_logprobs: InputLogProbs
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
index: int
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index")
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
@@ -189,6 +191,7 @@ class GenerateComplete(_message.Message):
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
INDEX_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int]
finish_reason: str
prompt_tokens: int
@@ -199,7 +202,8 @@ class GenerateComplete(_message.Message):
matched_token_id: int
matched_stop_str: str
input_logprobs: InputLogProbs
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
index: int
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details")