[router][grpc] Add logprobs support to router (#11082)
This commit is contained in:
@@ -82,6 +82,7 @@ class GrpcReqState:
|
|||||||
|
|
||||||
# Streaming state
|
# Streaming state
|
||||||
stream_finished: bool = False
|
stream_finished: bool = False
|
||||||
|
input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming
|
||||||
|
|
||||||
# Token accumulation (for non-streaming)
|
# Token accumulation (for non-streaming)
|
||||||
output_ids: List[int] = dataclasses.field(default_factory=list)
|
output_ids: List[int] = dataclasses.field(default_factory=list)
|
||||||
@@ -516,19 +517,105 @@ class GrpcRequestManager:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add logprobs if available
|
# Accumulate input logprobs (only once, usually in first chunk)
|
||||||
|
if batch_out.input_token_logprobs_val and i < len(
|
||||||
|
batch_out.input_token_logprobs_val
|
||||||
|
):
|
||||||
|
if not state.input_token_logprobs_val:
|
||||||
|
state.input_token_logprobs_val.extend(
|
||||||
|
batch_out.input_token_logprobs_val[i]
|
||||||
|
)
|
||||||
|
if batch_out.input_token_logprobs_idx and i < len(
|
||||||
|
batch_out.input_token_logprobs_idx
|
||||||
|
):
|
||||||
|
state.input_token_logprobs_idx.extend(
|
||||||
|
batch_out.input_token_logprobs_idx[i]
|
||||||
|
)
|
||||||
|
if batch_out.input_top_logprobs_val and i < len(
|
||||||
|
batch_out.input_top_logprobs_val
|
||||||
|
):
|
||||||
|
state.input_top_logprobs_val.extend(
|
||||||
|
batch_out.input_top_logprobs_val[i]
|
||||||
|
)
|
||||||
|
if batch_out.input_top_logprobs_idx and i < len(
|
||||||
|
batch_out.input_top_logprobs_idx
|
||||||
|
):
|
||||||
|
state.input_top_logprobs_idx.extend(
|
||||||
|
batch_out.input_top_logprobs_idx[i]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send input logprobs based on mode
|
||||||
|
if state.input_token_logprobs_val:
|
||||||
|
if state.obj.stream and not state.input_logprobs_sent:
|
||||||
|
# Streaming: send input logprobs once in first chunk that has them
|
||||||
|
output_data["input_logprobs"] = {
|
||||||
|
"token_logprobs_val": state.input_token_logprobs_val,
|
||||||
|
"token_logprobs_idx": state.input_token_logprobs_idx,
|
||||||
|
"top_logprobs_val": state.input_top_logprobs_val,
|
||||||
|
"top_logprobs_idx": state.input_top_logprobs_idx,
|
||||||
|
}
|
||||||
|
state.input_logprobs_sent = True
|
||||||
|
elif not state.obj.stream and output_data["finished"]:
|
||||||
|
# Non-streaming: send input logprobs in final chunk
|
||||||
|
output_data["input_logprobs"] = {
|
||||||
|
"token_logprobs_val": state.input_token_logprobs_val,
|
||||||
|
"token_logprobs_idx": state.input_token_logprobs_idx,
|
||||||
|
"top_logprobs_val": state.input_top_logprobs_val,
|
||||||
|
"top_logprobs_idx": state.input_top_logprobs_idx,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add output logprobs if available (RAW - no detokenization!)
|
||||||
if batch_out.output_token_logprobs_val and i < len(
|
if batch_out.output_token_logprobs_val and i < len(
|
||||||
batch_out.output_token_logprobs_val
|
batch_out.output_token_logprobs_val
|
||||||
):
|
):
|
||||||
output_data["logprobs"] = {
|
# Accumulate in state first
|
||||||
"tokens": batch_out.output_token_logprobs_val[i],
|
state.output_token_logprobs_val.extend(
|
||||||
"top_logprobs": (
|
batch_out.output_token_logprobs_val[i]
|
||||||
|
)
|
||||||
|
if batch_out.output_token_logprobs_idx and i < len(
|
||||||
|
batch_out.output_token_logprobs_idx
|
||||||
|
):
|
||||||
|
state.output_token_logprobs_idx.extend(
|
||||||
|
batch_out.output_token_logprobs_idx[i]
|
||||||
|
)
|
||||||
|
if batch_out.output_top_logprobs_val and i < len(
|
||||||
|
batch_out.output_top_logprobs_val
|
||||||
|
):
|
||||||
|
state.output_top_logprobs_val.extend(
|
||||||
batch_out.output_top_logprobs_val[i]
|
batch_out.output_top_logprobs_val[i]
|
||||||
if batch_out.output_top_logprobs_val
|
)
|
||||||
and i < len(batch_out.output_top_logprobs_val)
|
if batch_out.output_top_logprobs_idx and i < len(
|
||||||
else None
|
batch_out.output_top_logprobs_idx
|
||||||
),
|
):
|
||||||
}
|
state.output_top_logprobs_idx.extend(
|
||||||
|
batch_out.output_top_logprobs_idx[i]
|
||||||
|
)
|
||||||
|
|
||||||
|
if state.obj.stream:
|
||||||
|
# For streaming: send incremental logprobs (only new tokens in this chunk)
|
||||||
|
# NOTE: this is different than TokenizerManager, which always accumulates
|
||||||
|
def get_part(attr_name):
|
||||||
|
source_list = getattr(batch_out, attr_name, None)
|
||||||
|
return (
|
||||||
|
source_list[i]
|
||||||
|
if source_list and i < len(source_list)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
output_data["output_logprobs"] = {
|
||||||
|
"token_logprobs_val": batch_out.output_token_logprobs_val[i],
|
||||||
|
"token_logprobs_idx": get_part("output_token_logprobs_idx"),
|
||||||
|
"top_logprobs_val": get_part("output_top_logprobs_val"),
|
||||||
|
"top_logprobs_idx": get_part("output_top_logprobs_idx"),
|
||||||
|
}
|
||||||
|
elif output_data["finished"]:
|
||||||
|
# Non-streaming: send cumulative output logprobs in final chunk
|
||||||
|
output_data["output_logprobs"] = {
|
||||||
|
"token_logprobs_val": state.output_token_logprobs_val,
|
||||||
|
"token_logprobs_idx": state.output_token_logprobs_idx,
|
||||||
|
"top_logprobs_val": state.output_top_logprobs_val,
|
||||||
|
"top_logprobs_idx": state.output_top_logprobs_idx,
|
||||||
|
}
|
||||||
|
|
||||||
# Update state for accumulation
|
# Update state for accumulation
|
||||||
if output_data["token_ids"]:
|
if output_data["token_ids"]:
|
||||||
|
|||||||
@@ -472,11 +472,51 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
ignore_eos=grpc_params.ignore_eos,
|
ignore_eos=grpc_params.ignore_eos,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _convert_logprobs_to_proto(
|
||||||
|
self, logprobs_data: Dict
|
||||||
|
) -> Optional[sglang_scheduler_pb2.LogProbs]:
|
||||||
|
"""Convert logprobs dict to proto LogProbs format (transport RAW data only)."""
|
||||||
|
if not logprobs_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
|
||||||
|
token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
|
||||||
|
top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
|
||||||
|
top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
|
||||||
|
|
||||||
|
# Build TopLogProbs entries
|
||||||
|
top_logprobs_proto = []
|
||||||
|
if top_logprobs_val and top_logprobs_idx:
|
||||||
|
for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
|
||||||
|
top_logprobs_proto.append(
|
||||||
|
sglang_scheduler_pb2.TopLogProbs(
|
||||||
|
values=val_list,
|
||||||
|
token_ids=idx_list,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return sglang_scheduler_pb2.LogProbs(
|
||||||
|
token_logprobs=token_logprobs_val,
|
||||||
|
token_ids=token_logprobs_idx,
|
||||||
|
top_logprobs=top_logprobs_proto,
|
||||||
|
)
|
||||||
|
|
||||||
def _create_chunk_response(
|
def _create_chunk_response(
|
||||||
self, request_id: str, output: Dict
|
self, request_id: str, output: Dict
|
||||||
) -> sglang_scheduler_pb2.GenerateResponse:
|
) -> sglang_scheduler_pb2.GenerateResponse:
|
||||||
"""Create a streaming chunk response."""
|
"""Create a streaming chunk response."""
|
||||||
meta_info = output.get("meta_info", {})
|
meta_info = output.get("meta_info", {})
|
||||||
|
|
||||||
|
# Convert output logprobs if present
|
||||||
|
output_logprobs_proto = self._convert_logprobs_to_proto(
|
||||||
|
output.get("output_logprobs")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert input logprobs if present (only in first chunk)
|
||||||
|
input_logprobs_proto = self._convert_logprobs_to_proto(
|
||||||
|
output.get("input_logprobs")
|
||||||
|
)
|
||||||
|
|
||||||
return sglang_scheduler_pb2.GenerateResponse(
|
return sglang_scheduler_pb2.GenerateResponse(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
|
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
|
||||||
@@ -484,6 +524,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
prompt_tokens=meta_info.get("prompt_tokens", 0),
|
prompt_tokens=meta_info.get("prompt_tokens", 0),
|
||||||
completion_tokens=meta_info.get("completion_tokens", 0),
|
completion_tokens=meta_info.get("completion_tokens", 0),
|
||||||
cached_tokens=meta_info.get("cached_tokens", 0),
|
cached_tokens=meta_info.get("cached_tokens", 0),
|
||||||
|
output_logprobs=output_logprobs_proto,
|
||||||
|
input_logprobs=input_logprobs_proto,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -519,6 +561,16 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
elif isinstance(matched, str):
|
elif isinstance(matched, str):
|
||||||
matched_stop_kwargs["matched_stop_str"] = matched
|
matched_stop_kwargs["matched_stop_str"] = matched
|
||||||
|
|
||||||
|
# Convert output logprobs if present
|
||||||
|
output_logprobs_proto = self._convert_logprobs_to_proto(
|
||||||
|
output.get("output_logprobs")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert input logprobs if present
|
||||||
|
input_logprobs_proto = self._convert_logprobs_to_proto(
|
||||||
|
output.get("input_logprobs")
|
||||||
|
)
|
||||||
|
|
||||||
return sglang_scheduler_pb2.GenerateResponse(
|
return sglang_scheduler_pb2.GenerateResponse(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
complete=sglang_scheduler_pb2.GenerateComplete(
|
complete=sglang_scheduler_pb2.GenerateComplete(
|
||||||
@@ -529,6 +581,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
"completion_tokens", len(output.get("token_ids", []))
|
"completion_tokens", len(output.get("token_ids", []))
|
||||||
),
|
),
|
||||||
cached_tokens=meta_info.get("cached_tokens", 0),
|
cached_tokens=meta_info.get("cached_tokens", 0),
|
||||||
|
output_logprobs=output_logprobs_proto,
|
||||||
|
input_logprobs=input_logprobs_proto,
|
||||||
**matched_stop_kwargs,
|
**matched_stop_kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -174,11 +174,14 @@ message GenerateStreamChunk {
|
|||||||
int32 completion_tokens = 3;
|
int32 completion_tokens = 3;
|
||||||
int32 cached_tokens = 4;
|
int32 cached_tokens = 4;
|
||||||
|
|
||||||
// Logprobs (if requested)
|
// Output logprobs (if requested) - incremental for streaming
|
||||||
LogProbs logprobs = 5;
|
LogProbs output_logprobs = 5;
|
||||||
|
|
||||||
// Hidden states (if requested)
|
// Hidden states (if requested)
|
||||||
repeated float hidden_states = 6;
|
repeated float hidden_states = 6;
|
||||||
|
|
||||||
|
// Input logprobs (if requested) - only in first chunk
|
||||||
|
LogProbs input_logprobs = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateComplete {
|
message GenerateComplete {
|
||||||
@@ -193,8 +196,8 @@ message GenerateComplete {
|
|||||||
int32 completion_tokens = 4;
|
int32 completion_tokens = 4;
|
||||||
int32 cached_tokens = 5;
|
int32 cached_tokens = 5;
|
||||||
|
|
||||||
// All logprobs if requested
|
// Output logprobs if requested (cumulative)
|
||||||
repeated LogProbs all_logprobs = 6;
|
LogProbs output_logprobs = 6;
|
||||||
|
|
||||||
// All hidden states if requested
|
// All hidden states if requested
|
||||||
repeated HiddenStates all_hidden_states = 7;
|
repeated HiddenStates all_hidden_states = 7;
|
||||||
@@ -204,6 +207,9 @@ message GenerateComplete {
|
|||||||
uint32 matched_token_id = 8;
|
uint32 matched_token_id = 8;
|
||||||
string matched_stop_str = 9;
|
string matched_stop_str = 9;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Input logprobs if requested (for prompt tokens)
|
||||||
|
LogProbs input_logprobs = 10;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateError {
|
message GenerateError {
|
||||||
@@ -218,15 +224,11 @@ message LogProbs {
|
|||||||
|
|
||||||
// Top logprobs at each position
|
// Top logprobs at each position
|
||||||
repeated TopLogProbs top_logprobs = 3;
|
repeated TopLogProbs top_logprobs = 3;
|
||||||
|
|
||||||
// Decoded text for tokens
|
|
||||||
repeated string token_texts = 4;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message TopLogProbs {
|
message TopLogProbs {
|
||||||
repeated float values = 1;
|
repeated float values = 1;
|
||||||
repeated int32 token_ids = 2;
|
repeated int32 token_ids = 2;
|
||||||
repeated string token_texts = 3;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message HiddenStates {
|
message HiddenStates {
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -162,42 +162,46 @@ class GenerateResponse(_message.Message):
|
|||||||
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
|
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateStreamChunk(_message.Message):
|
class GenerateStreamChunk(_message.Message):
|
||||||
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
|
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
|
||||||
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int
|
completion_tokens: int
|
||||||
cached_tokens: int
|
cached_tokens: int
|
||||||
logprobs: LogProbs
|
output_logprobs: LogProbs
|
||||||
hidden_states: _containers.RepeatedScalarFieldContainer[float]
|
hidden_states: _containers.RepeatedScalarFieldContainer[float]
|
||||||
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
|
input_logprobs: LogProbs
|
||||||
|
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateComplete(_message.Message):
|
class GenerateComplete(_message.Message):
|
||||||
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str")
|
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
|
||||||
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
|
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||||
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
|
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
|
||||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
||||||
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
|
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
|
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
output_ids: _containers.RepeatedScalarFieldContainer[int]
|
output_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||||
finish_reason: str
|
finish_reason: str
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int
|
completion_tokens: int
|
||||||
cached_tokens: int
|
cached_tokens: int
|
||||||
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
|
output_logprobs: LogProbs
|
||||||
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
|
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
|
||||||
matched_token_id: int
|
matched_token_id: int
|
||||||
matched_stop_str: str
|
matched_stop_str: str
|
||||||
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ...) -> None: ...
|
input_logprobs: LogProbs
|
||||||
|
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateError(_message.Message):
|
class GenerateError(_message.Message):
|
||||||
__slots__ = ("message", "http_status_code", "details")
|
__slots__ = ("message", "http_status_code", "details")
|
||||||
@@ -210,26 +214,22 @@ class GenerateError(_message.Message):
|
|||||||
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
|
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
|
||||||
|
|
||||||
class LogProbs(_message.Message):
|
class LogProbs(_message.Message):
|
||||||
__slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
|
__slots__ = ("token_logprobs", "token_ids", "top_logprobs")
|
||||||
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||||
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
|
|
||||||
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
|
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
|
||||||
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||||
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
|
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
|
||||||
token_texts: _containers.RepeatedScalarFieldContainer[str]
|
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
|
||||||
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
|
|
||||||
|
|
||||||
class TopLogProbs(_message.Message):
|
class TopLogProbs(_message.Message):
|
||||||
__slots__ = ("values", "token_ids", "token_texts")
|
__slots__ = ("values", "token_ids")
|
||||||
VALUES_FIELD_NUMBER: _ClassVar[int]
|
VALUES_FIELD_NUMBER: _ClassVar[int]
|
||||||
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||||
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
|
|
||||||
values: _containers.RepeatedScalarFieldContainer[float]
|
values: _containers.RepeatedScalarFieldContainer[float]
|
||||||
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||||
token_texts: _containers.RepeatedScalarFieldContainer[str]
|
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...
|
||||||
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
|
|
||||||
|
|
||||||
class HiddenStates(_message.Message):
|
class HiddenStates(_message.Message):
|
||||||
__slots__ = ("values", "layer", "position")
|
__slots__ = ("values", "layer", "position")
|
||||||
|
|||||||
@@ -174,11 +174,14 @@ message GenerateStreamChunk {
|
|||||||
int32 completion_tokens = 3;
|
int32 completion_tokens = 3;
|
||||||
int32 cached_tokens = 4;
|
int32 cached_tokens = 4;
|
||||||
|
|
||||||
// Logprobs (if requested)
|
// Output logprobs (if requested) - incremental for streaming
|
||||||
LogProbs logprobs = 5;
|
LogProbs output_logprobs = 5;
|
||||||
|
|
||||||
// Hidden states (if requested)
|
// Hidden states (if requested)
|
||||||
repeated float hidden_states = 6;
|
repeated float hidden_states = 6;
|
||||||
|
|
||||||
|
// Input logprobs (if requested) - only in first chunk
|
||||||
|
LogProbs input_logprobs = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateComplete {
|
message GenerateComplete {
|
||||||
@@ -193,8 +196,8 @@ message GenerateComplete {
|
|||||||
int32 completion_tokens = 4;
|
int32 completion_tokens = 4;
|
||||||
int32 cached_tokens = 5;
|
int32 cached_tokens = 5;
|
||||||
|
|
||||||
// All logprobs if requested
|
// Output logprobs if requested (cumulative)
|
||||||
repeated LogProbs all_logprobs = 6;
|
LogProbs output_logprobs = 6;
|
||||||
|
|
||||||
// All hidden states if requested
|
// All hidden states if requested
|
||||||
repeated HiddenStates all_hidden_states = 7;
|
repeated HiddenStates all_hidden_states = 7;
|
||||||
@@ -204,6 +207,9 @@ message GenerateComplete {
|
|||||||
uint32 matched_token_id = 8;
|
uint32 matched_token_id = 8;
|
||||||
string matched_stop_str = 9;
|
string matched_stop_str = 9;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Input logprobs if requested (for prompt tokens)
|
||||||
|
LogProbs input_logprobs = 10;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateError {
|
message GenerateError {
|
||||||
@@ -218,15 +224,11 @@ message LogProbs {
|
|||||||
|
|
||||||
// Top logprobs at each position
|
// Top logprobs at each position
|
||||||
repeated TopLogProbs top_logprobs = 3;
|
repeated TopLogProbs top_logprobs = 3;
|
||||||
|
|
||||||
// Decoded text for tokens
|
|
||||||
repeated string token_texts = 4;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message TopLogProbs {
|
message TopLogProbs {
|
||||||
repeated float values = 1;
|
repeated float values = 1;
|
||||||
repeated int32 token_ids = 2;
|
repeated int32 token_ids = 2;
|
||||||
repeated string token_texts = 3;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message HiddenStates {
|
message HiddenStates {
|
||||||
|
|||||||
@@ -730,6 +730,73 @@ impl GrpcRouter {
|
|||||||
Json(response).into_response()
|
Json(response).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert proto LogProbs to OpenAI ChatLogProbs format
|
||||||
|
/// Note: Always decodes with skip_special_tokens=false to show actual tokens generated
|
||||||
|
fn convert_proto_to_openai_logprobs(
|
||||||
|
&self,
|
||||||
|
proto_logprobs: &proto::LogProbs,
|
||||||
|
) -> Result<crate::protocols::spec::ChatLogProbs, String> {
|
||||||
|
let mut content_items = Vec::new();
|
||||||
|
|
||||||
|
// Decode token IDs to text (always with skip_special_tokens=false for logprobs)
|
||||||
|
let token_texts: Vec<String> = proto_logprobs
|
||||||
|
.token_ids
|
||||||
|
.iter()
|
||||||
|
.map(|&token_id| {
|
||||||
|
self.tokenizer
|
||||||
|
.decode(&[token_id as u32], false)
|
||||||
|
.unwrap_or_else(|_| format!("<token_{}>", token_id))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Build ChatLogProbsContent for each token
|
||||||
|
for (i, &logprob) in proto_logprobs.token_logprobs.iter().enumerate() {
|
||||||
|
let token_text = token_texts.get(i).cloned().unwrap_or_default();
|
||||||
|
let bytes = Some(token_text.as_bytes().to_vec());
|
||||||
|
|
||||||
|
// Build top_logprobs for this position
|
||||||
|
let mut top_logprobs = Vec::new();
|
||||||
|
if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) {
|
||||||
|
// Decode top token IDs (always with skip_special_tokens=false)
|
||||||
|
let top_token_texts: Vec<String> = top_logprobs_entry
|
||||||
|
.token_ids
|
||||||
|
.iter()
|
||||||
|
.map(|&tid| {
|
||||||
|
self.tokenizer
|
||||||
|
.decode(&[tid as u32], false)
|
||||||
|
.unwrap_or_else(|_| format!("<token_{}>", tid))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry
|
||||||
|
.values
|
||||||
|
.iter()
|
||||||
|
.zip(top_logprobs_entry.token_ids.iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
if let Some(top_token_text) = top_token_texts.get(j) {
|
||||||
|
top_logprobs.push(crate::protocols::spec::TopLogProb {
|
||||||
|
token: top_token_text.clone(),
|
||||||
|
logprob: top_logprob,
|
||||||
|
bytes: Some(top_token_text.as_bytes().to_vec()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
content_items.push(crate::protocols::spec::ChatLogProbsContent {
|
||||||
|
token: token_text,
|
||||||
|
logprob,
|
||||||
|
bytes,
|
||||||
|
top_logprobs,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(crate::protocols::spec::ChatLogProbs::Detailed {
|
||||||
|
content: (!content_items.is_empty()).then_some(content_items),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Process a single GenerateComplete response into a ChatChoice
|
/// Process a single GenerateComplete response into a ChatChoice
|
||||||
async fn process_single_choice(
|
async fn process_single_choice(
|
||||||
&self,
|
&self,
|
||||||
@@ -855,7 +922,22 @@ impl GrpcRouter {
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Step 4: Build ChatCompletionMessage (proper response message type)
|
// Step 4: Convert output logprobs if present
|
||||||
|
// Note: complete.input_logprobs exists in proto but is not used for chat completions
|
||||||
|
// (input logprobs are only used in /v1/completions endpoint with echo=true)
|
||||||
|
let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs {
|
||||||
|
match self.convert_proto_to_openai_logprobs(proto_logprobs) {
|
||||||
|
Ok(logprobs) => Some(logprobs),
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to convert logprobs: {}", e);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Step 5: Build ChatCompletionMessage (proper response message type)
|
||||||
let chat_message = ChatCompletionMessage {
|
let chat_message = ChatCompletionMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: if processed_text.is_empty() {
|
content: if processed_text.is_empty() {
|
||||||
@@ -867,11 +949,11 @@ impl GrpcRouter {
|
|||||||
reasoning_content: reasoning_text,
|
reasoning_content: reasoning_text,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Step 5: Build ChatChoice
|
// Step 6: Build ChatChoice
|
||||||
let choice = ChatChoice {
|
let choice = ChatChoice {
|
||||||
index: index as u32,
|
index: index as u32,
|
||||||
message: chat_message,
|
message: chat_message,
|
||||||
logprobs: None,
|
logprobs,
|
||||||
finish_reason: Some(final_finish_reason_str.to_string()),
|
finish_reason: Some(final_finish_reason_str.to_string()),
|
||||||
matched_stop,
|
matched_stop,
|
||||||
hidden_states: None,
|
hidden_states: None,
|
||||||
|
|||||||
Reference in New Issue
Block a user