[router][grpc] Add logprobs support to router (#11082)

This commit is contained in:
Chang Su
2025-09-29 15:55:06 -07:00
committed by GitHub
parent f065e5bea5
commit 5937a56d47
7 changed files with 323 additions and 96 deletions

View File

@@ -82,6 +82,7 @@ class GrpcReqState:
# Streaming state
stream_finished: bool = False
input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming
# Token accumulation (for non-streaming)
output_ids: List[int] = dataclasses.field(default_factory=list)
@@ -516,19 +517,105 @@ class GrpcRequestManager:
},
}
# Add logprobs if available
# Accumulate input logprobs (only once, usually in first chunk)
if batch_out.input_token_logprobs_val and i < len(
batch_out.input_token_logprobs_val
):
if not state.input_token_logprobs_val:
state.input_token_logprobs_val.extend(
batch_out.input_token_logprobs_val[i]
)
if batch_out.input_token_logprobs_idx and i < len(
batch_out.input_token_logprobs_idx
):
state.input_token_logprobs_idx.extend(
batch_out.input_token_logprobs_idx[i]
)
if batch_out.input_top_logprobs_val and i < len(
batch_out.input_top_logprobs_val
):
state.input_top_logprobs_val.extend(
batch_out.input_top_logprobs_val[i]
)
if batch_out.input_top_logprobs_idx and i < len(
batch_out.input_top_logprobs_idx
):
state.input_top_logprobs_idx.extend(
batch_out.input_top_logprobs_idx[i]
)
# Send input logprobs based on mode
if state.input_token_logprobs_val:
if state.obj.stream and not state.input_logprobs_sent:
# Streaming: send input logprobs once in first chunk that has them
output_data["input_logprobs"] = {
"token_logprobs_val": state.input_token_logprobs_val,
"token_logprobs_idx": state.input_token_logprobs_idx,
"top_logprobs_val": state.input_top_logprobs_val,
"top_logprobs_idx": state.input_top_logprobs_idx,
}
state.input_logprobs_sent = True
elif not state.obj.stream and output_data["finished"]:
# Non-streaming: send input logprobs in final chunk
output_data["input_logprobs"] = {
"token_logprobs_val": state.input_token_logprobs_val,
"token_logprobs_idx": state.input_token_logprobs_idx,
"top_logprobs_val": state.input_top_logprobs_val,
"top_logprobs_idx": state.input_top_logprobs_idx,
}
# Add output logprobs if available (RAW - no detokenization!)
if batch_out.output_token_logprobs_val and i < len(
batch_out.output_token_logprobs_val
):
output_data["logprobs"] = {
"tokens": batch_out.output_token_logprobs_val[i],
"top_logprobs": (
# Accumulate in state first
state.output_token_logprobs_val.extend(
batch_out.output_token_logprobs_val[i]
)
if batch_out.output_token_logprobs_idx and i < len(
batch_out.output_token_logprobs_idx
):
state.output_token_logprobs_idx.extend(
batch_out.output_token_logprobs_idx[i]
)
if batch_out.output_top_logprobs_val and i < len(
batch_out.output_top_logprobs_val
):
state.output_top_logprobs_val.extend(
batch_out.output_top_logprobs_val[i]
if batch_out.output_top_logprobs_val
and i < len(batch_out.output_top_logprobs_val)
else None
),
}
)
if batch_out.output_top_logprobs_idx and i < len(
batch_out.output_top_logprobs_idx
):
state.output_top_logprobs_idx.extend(
batch_out.output_top_logprobs_idx[i]
)
if state.obj.stream:
# For streaming: send incremental logprobs (only new tokens in this chunk)
# NOTE: this is different than TokenizerManager, which always accumulates
def get_part(attr_name):
source_list = getattr(batch_out, attr_name, None)
return (
source_list[i]
if source_list and i < len(source_list)
else []
)
output_data["output_logprobs"] = {
"token_logprobs_val": batch_out.output_token_logprobs_val[i],
"token_logprobs_idx": get_part("output_token_logprobs_idx"),
"top_logprobs_val": get_part("output_top_logprobs_val"),
"top_logprobs_idx": get_part("output_top_logprobs_idx"),
}
elif output_data["finished"]:
# Non-streaming: send cumulative output logprobs in final chunk
output_data["output_logprobs"] = {
"token_logprobs_val": state.output_token_logprobs_val,
"token_logprobs_idx": state.output_token_logprobs_idx,
"top_logprobs_val": state.output_top_logprobs_val,
"top_logprobs_idx": state.output_top_logprobs_idx,
}
# Update state for accumulation
if output_data["token_ids"]:

View File

@@ -472,11 +472,51 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
ignore_eos=grpc_params.ignore_eos,
)
def _convert_logprobs_to_proto(
self, logprobs_data: Dict
) -> Optional[sglang_scheduler_pb2.LogProbs]:
"""Convert logprobs dict to proto LogProbs format (transport RAW data only)."""
if not logprobs_data:
return None
token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
# Build TopLogProbs entries
top_logprobs_proto = []
if top_logprobs_val and top_logprobs_idx:
for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
top_logprobs_proto.append(
sglang_scheduler_pb2.TopLogProbs(
values=val_list,
token_ids=idx_list,
)
)
return sglang_scheduler_pb2.LogProbs(
token_logprobs=token_logprobs_val,
token_ids=token_logprobs_idx,
top_logprobs=top_logprobs_proto,
)
def _create_chunk_response(
self, request_id: str, output: Dict
) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a streaming chunk response."""
meta_info = output.get("meta_info", {})
# Convert output logprobs if present
output_logprobs_proto = self._convert_logprobs_to_proto(
output.get("output_logprobs")
)
# Convert input logprobs if present (only in first chunk)
input_logprobs_proto = self._convert_logprobs_to_proto(
output.get("input_logprobs")
)
return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id,
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
@@ -484,6 +524,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
prompt_tokens=meta_info.get("prompt_tokens", 0),
completion_tokens=meta_info.get("completion_tokens", 0),
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
),
)
@@ -519,6 +561,16 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
elif isinstance(matched, str):
matched_stop_kwargs["matched_stop_str"] = matched
# Convert output logprobs if present
output_logprobs_proto = self._convert_logprobs_to_proto(
output.get("output_logprobs")
)
# Convert input logprobs if present
input_logprobs_proto = self._convert_logprobs_to_proto(
output.get("input_logprobs")
)
return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id,
complete=sglang_scheduler_pb2.GenerateComplete(
@@ -529,6 +581,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
"completion_tokens", len(output.get("token_ids", []))
),
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
**matched_stop_kwargs,
),
)

View File

@@ -174,11 +174,14 @@ message GenerateStreamChunk {
int32 completion_tokens = 3;
int32 cached_tokens = 4;
// Logprobs (if requested)
LogProbs logprobs = 5;
// Output logprobs (if requested) - incremental for streaming
LogProbs output_logprobs = 5;
// Hidden states (if requested)
repeated float hidden_states = 6;
// Input logprobs (if requested) - only in first chunk
LogProbs input_logprobs = 7;
}
message GenerateComplete {
@@ -193,8 +196,8 @@ message GenerateComplete {
int32 completion_tokens = 4;
int32 cached_tokens = 5;
// All logprobs if requested
repeated LogProbs all_logprobs = 6;
// Output logprobs if requested (cumulative)
LogProbs output_logprobs = 6;
// All hidden states if requested
repeated HiddenStates all_hidden_states = 7;
@@ -204,6 +207,9 @@ message GenerateComplete {
uint32 matched_token_id = 8;
string matched_stop_str = 9;
}
// Input logprobs if requested (for prompt tokens)
LogProbs input_logprobs = 10;
}
message GenerateError {
@@ -218,15 +224,11 @@ message LogProbs {
// Top logprobs at each position
repeated TopLogProbs top_logprobs = 3;
// Decoded text for tokens
repeated string token_texts = 4;
}
message TopLogProbs {
repeated float values = 1;
repeated int32 token_ids = 2;
repeated string token_texts = 3;
}
message HiddenStates {

File diff suppressed because one or more lines are too long

View File

@@ -162,42 +162,46 @@ class GenerateResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message):
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
token_ids: _containers.RepeatedScalarFieldContainer[int]
prompt_tokens: int
completion_tokens: int
cached_tokens: int
logprobs: LogProbs
output_logprobs: LogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
input_logprobs: LogProbs
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str")
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int]
finish_reason: str
prompt_tokens: int
completion_tokens: int
cached_tokens: int
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
output_logprobs: LogProbs
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
matched_token_id: int
matched_stop_str: str
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ...) -> None: ...
input_logprobs: LogProbs
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details")
@@ -210,26 +214,22 @@ class GenerateError(_message.Message):
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class LogProbs(_message.Message):
__slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
__slots__ = ("token_logprobs", "token_ids", "top_logprobs")
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
class TopLogProbs(_message.Message):
__slots__ = ("values", "token_ids", "token_texts")
__slots__ = ("values", "token_ids")
VALUES_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...
class HiddenStates(_message.Message):
__slots__ = ("values", "layer", "position")