[OAI Server Refactor] [ChatCompletions & Completions] Support Return Hidden State (#7329)
Signed-off-by: keru <rukeyang@gmail.com>
This commit is contained in:
@@ -19,7 +19,10 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
|
||||
from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
|
||||
from sglang.srt.entrypoints.openai.utils import (
|
||||
process_hidden_states_from_ret,
|
||||
to_openai_style_logprobs,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -76,6 +79,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
bootstrap_host=request.bootstrap_host,
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
@@ -188,6 +192,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
delta = text[len(stream_buffer) :]
|
||||
stream_buffers[index] = stream_buffer + delta
|
||||
finish_reason = content["meta_info"]["finish_reason"]
|
||||
hidden_states = content["meta_info"].get("hidden_states", None)
|
||||
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=index,
|
||||
@@ -210,6 +215,30 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
if request.return_hidden_states and hidden_states:
|
||||
for index, choice_hidden_states in hidden_states.items():
|
||||
if choice_hidden_states:
|
||||
last_token_hidden_states = (
|
||||
choice_hidden_states[-1]
|
||||
if len(choice_hidden_states) > 1
|
||||
else []
|
||||
)
|
||||
hidden_states_chunk = CompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=created,
|
||||
object="text_completion",
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(
|
||||
index=index,
|
||||
text="",
|
||||
hidden_states=last_token_hidden_states,
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Handle final usage chunk
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
usage = UsageProcessor.calculate_streaming_usage(
|
||||
@@ -304,6 +333,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
|
||||
)
|
||||
|
||||
# Handle hidden states
|
||||
hidden_states = process_hidden_states_from_ret(ret_item, request)
|
||||
|
||||
finish_reason = ret_item["meta_info"]["finish_reason"]
|
||||
|
||||
choice_data = CompletionResponseChoice(
|
||||
@@ -316,6 +348,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user