[OAI Server Refactor] [ChatCompletions & Completions] Support Return Hidden State (#7329)
Signed-off-by: keru <rukeyang@gmail.com>
This commit is contained in:
@@ -30,6 +30,7 @@ from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
|
||||
from sglang.srt.entrypoints.openai.utils import (
|
||||
detect_template_content_format,
|
||||
process_content_for_template_format,
|
||||
process_hidden_states_from_ret,
|
||||
to_openai_style_logprobs,
|
||||
)
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
@@ -99,6 +100,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
bootstrap_host=request.bootstrap_host,
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
@@ -402,6 +404,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
prompt_tokens = {}
|
||||
completion_tokens = {}
|
||||
cached_tokens = {}
|
||||
hidden_states = {}
|
||||
|
||||
try:
|
||||
async for content in self.tokenizer_manager.generate_request(
|
||||
@@ -412,6 +415,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
||||
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
||||
hidden_states[index] = content["meta_info"].get("hidden_states", None)
|
||||
|
||||
# Handle logprobs
|
||||
choice_logprobs = None
|
||||
@@ -544,6 +548,31 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
)
|
||||
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Send hidden states if requested
|
||||
if request.return_hidden_states and hidden_states:
|
||||
for index, choice_hidden_states in hidden_states.items():
|
||||
if choice_hidden_states:
|
||||
last_token_hidden_states = (
|
||||
choice_hidden_states[-1]
|
||||
if len(choice_hidden_states) > 1
|
||||
else []
|
||||
)
|
||||
hidden_states_chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=int(time.time()),
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(
|
||||
hidden_states=last_token_hidden_states
|
||||
),
|
||||
finish_reason=finish_reason_type,
|
||||
)
|
||||
],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Additional usage chunk
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
usage = UsageProcessor.calculate_streaming_usage(
|
||||
@@ -608,6 +637,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
if request.logprobs:
|
||||
choice_logprobs = self._process_response_logprobs(ret_item)
|
||||
|
||||
# Handle hidden states
|
||||
hidden_states = process_hidden_states_from_ret(ret_item, request)
|
||||
|
||||
finish_reason = ret_item["meta_info"]["finish_reason"]
|
||||
text = ret_item["text"]
|
||||
|
||||
@@ -654,6 +686,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user