[OAI Server Refactor] [ChatCompletions & Completions] Support Return Hidden State (#7329)

Signed-off-by: keru <rukeyang@gmail.com>
This commit is contained in:
Keyang Ru
2025-06-20 19:18:53 -07:00
committed by GitHub
parent 4d8d9b8efd
commit 5e7fdc79fa
5 changed files with 184 additions and 3 deletions

View File

@@ -30,6 +30,7 @@ from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import (
detect_template_content_format,
process_content_for_template_format,
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
@@ -99,6 +100,7 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
)
return adapted_request, request
@@ -402,6 +404,7 @@ class OpenAIServingChat(OpenAIServingBase):
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
hidden_states = {}
try:
async for content in self.tokenizer_manager.generate_request(
@@ -412,6 +415,7 @@ class OpenAIServingChat(OpenAIServingBase):
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
hidden_states[index] = content["meta_info"].get("hidden_states", None)
# Handle logprobs
choice_logprobs = None
@@ -544,6 +548,31 @@ class OpenAIServingChat(OpenAIServingBase):
)
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
# Send hidden states if requested
if request.return_hidden_states and hidden_states:
for index, choice_hidden_states in hidden_states.items():
if choice_hidden_states:
last_token_hidden_states = (
choice_hidden_states[-1]
if len(choice_hidden_states) > 1
else []
)
hidden_states_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(
hidden_states=last_token_hidden_states
),
finish_reason=finish_reason_type,
)
],
model=request.model,
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
# Additional usage chunk
if request.stream_options and request.stream_options.include_usage:
usage = UsageProcessor.calculate_streaming_usage(
@@ -608,6 +637,9 @@ class OpenAIServingChat(OpenAIServingBase):
if request.logprobs:
choice_logprobs = self._process_response_logprobs(ret_item)
# Handle hidden states
hidden_states = process_hidden_states_from_ret(ret_item, request)
finish_reason = ret_item["meta_info"]["finish_reason"]
text = ret_item["text"]
@@ -654,6 +686,7 @@ class OpenAIServingChat(OpenAIServingBase):
if finish_reason and "matched" in finish_reason
else None
),
hidden_states=hidden_states,
)
choices.append(choice_data)