[OAI Server Refactor] [ChatCompletions & Completions] Support Return Hidden State (#7329)

Signed-off-by: keru <rukeyang@gmail.com>
This commit is contained in:
Keyang Ru
2025-06-20 19:18:53 -07:00
committed by GitHub
parent 4d8d9b8efd
commit 5e7fdc79fa
5 changed files with 184 additions and 3 deletions

View File

@@ -19,7 +19,10 @@ from sglang.srt.entrypoints.openai.protocol import (
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
from sglang.srt.entrypoints.openai.utils import (
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.managers.io_struct import GenerateReqInput
logger = logging.getLogger(__name__)
@@ -76,6 +79,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
)
return adapted_request, request
@@ -188,6 +192,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
delta = text[len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
finish_reason = content["meta_info"]["finish_reason"]
hidden_states = content["meta_info"].get("hidden_states", None)
choice_data = CompletionResponseStreamChoice(
index=index,
@@ -210,6 +215,30 @@ class OpenAIServingCompletion(OpenAIServingBase):
yield f"data: {chunk.model_dump_json()}\n\n"
if request.return_hidden_states and hidden_states:
for index, choice_hidden_states in hidden_states.items():
if choice_hidden_states:
last_token_hidden_states = (
choice_hidden_states[-1]
if len(choice_hidden_states) > 1
else []
)
hidden_states_chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
object="text_completion",
choices=[
CompletionResponseStreamChoice(
index=index,
text="",
hidden_states=last_token_hidden_states,
finish_reason=None,
)
],
model=request.model,
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
# Handle final usage chunk
if request.stream_options and request.stream_options.include_usage:
usage = UsageProcessor.calculate_streaming_usage(
@@ -304,6 +333,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
)
# Handle hidden states
hidden_states = process_hidden_states_from_ret(ret_item, request)
finish_reason = ret_item["meta_info"]["finish_reason"]
choice_data = CompletionResponseChoice(
@@ -316,6 +348,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
if finish_reason and "matched" in finish_reason
else None
),
hidden_states=hidden_states,
)
choices.append(choice_data)