[OAI Server Refactor] [ChatCompletions & Completions] Support Return Hidden State (#7329)
Signed-off-by: keru <rukeyang@gmail.com>
This commit is contained in:
@@ -16,7 +16,13 @@
|
||||
import time
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@@ -167,6 +173,7 @@ class CompletionRequest(BaseModel):
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
user: Optional[str] = None
|
||||
return_hidden_states: bool = False
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
top_k: int = -1
|
||||
@@ -202,6 +209,14 @@ class CompletionResponseChoice(BaseModel):
|
||||
logprobs: Optional[LogProbs] = None
|
||||
finish_reason: Literal["stop", "length", "content_filter", "abort"]
|
||||
matched_stop: Union[None, int, str] = None
|
||||
hidden_states: Optional[object] = None
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.hidden_states is None:
|
||||
data.pop("hidden_states", None)
|
||||
return data
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
@@ -219,6 +234,14 @@ class CompletionResponseStreamChoice(BaseModel):
|
||||
logprobs: Optional[LogProbs] = None
|
||||
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
|
||||
matched_stop: Union[None, int, str] = None
|
||||
hidden_states: Optional[object] = None
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.hidden_states is None:
|
||||
data.pop("hidden_states", None)
|
||||
return data
|
||||
|
||||
|
||||
class CompletionStreamResponse(BaseModel):
|
||||
@@ -376,6 +399,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
|
||||
default="auto", examples=["none"]
|
||||
) # noqa
|
||||
return_hidden_states: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -437,6 +461,14 @@ class ChatCompletionResponseChoice(BaseModel):
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
|
||||
]
|
||||
matched_stop: Union[None, int, str] = None
|
||||
hidden_states: Optional[object] = None
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.hidden_states is None:
|
||||
data.pop("hidden_states", None)
|
||||
return data
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
@@ -453,6 +485,14 @@ class DeltaMessage(BaseModel):
|
||||
content: Optional[str] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||
hidden_states: Optional[object] = None
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.hidden_states is None:
|
||||
data.pop("hidden_states", None)
|
||||
return data
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
|
||||
@@ -30,6 +30,7 @@ from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
|
||||
from sglang.srt.entrypoints.openai.utils import (
|
||||
detect_template_content_format,
|
||||
process_content_for_template_format,
|
||||
process_hidden_states_from_ret,
|
||||
to_openai_style_logprobs,
|
||||
)
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
@@ -99,6 +100,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
bootstrap_host=request.bootstrap_host,
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
@@ -402,6 +404,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
prompt_tokens = {}
|
||||
completion_tokens = {}
|
||||
cached_tokens = {}
|
||||
hidden_states = {}
|
||||
|
||||
try:
|
||||
async for content in self.tokenizer_manager.generate_request(
|
||||
@@ -412,6 +415,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
||||
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
||||
hidden_states[index] = content["meta_info"].get("hidden_states", None)
|
||||
|
||||
# Handle logprobs
|
||||
choice_logprobs = None
|
||||
@@ -544,6 +548,31 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
)
|
||||
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Send hidden states if requested
|
||||
if request.return_hidden_states and hidden_states:
|
||||
for index, choice_hidden_states in hidden_states.items():
|
||||
if choice_hidden_states:
|
||||
last_token_hidden_states = (
|
||||
choice_hidden_states[-1]
|
||||
if len(choice_hidden_states) > 1
|
||||
else []
|
||||
)
|
||||
hidden_states_chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=int(time.time()),
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(
|
||||
hidden_states=last_token_hidden_states
|
||||
),
|
||||
finish_reason=finish_reason_type,
|
||||
)
|
||||
],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Additional usage chunk
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
usage = UsageProcessor.calculate_streaming_usage(
|
||||
@@ -608,6 +637,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
if request.logprobs:
|
||||
choice_logprobs = self._process_response_logprobs(ret_item)
|
||||
|
||||
# Handle hidden states
|
||||
hidden_states = process_hidden_states_from_ret(ret_item, request)
|
||||
|
||||
finish_reason = ret_item["meta_info"]["finish_reason"]
|
||||
text = ret_item["text"]
|
||||
|
||||
@@ -654,6 +686,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
|
||||
@@ -19,7 +19,10 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
|
||||
from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
|
||||
from sglang.srt.entrypoints.openai.utils import (
|
||||
process_hidden_states_from_ret,
|
||||
to_openai_style_logprobs,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -76,6 +79,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
bootstrap_host=request.bootstrap_host,
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
@@ -188,6 +192,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
delta = text[len(stream_buffer) :]
|
||||
stream_buffers[index] = stream_buffer + delta
|
||||
finish_reason = content["meta_info"]["finish_reason"]
|
||||
hidden_states = content["meta_info"].get("hidden_states", None)
|
||||
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=index,
|
||||
@@ -210,6 +215,30 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
if request.return_hidden_states and hidden_states:
|
||||
for index, choice_hidden_states in hidden_states.items():
|
||||
if choice_hidden_states:
|
||||
last_token_hidden_states = (
|
||||
choice_hidden_states[-1]
|
||||
if len(choice_hidden_states) > 1
|
||||
else []
|
||||
)
|
||||
hidden_states_chunk = CompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=created,
|
||||
object="text_completion",
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(
|
||||
index=index,
|
||||
text="",
|
||||
hidden_states=last_token_hidden_states,
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Handle final usage chunk
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
usage = UsageProcessor.calculate_streaming_usage(
|
||||
@@ -304,6 +333,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
|
||||
)
|
||||
|
||||
# Handle hidden states
|
||||
hidden_states = process_hidden_states_from_ret(ret_item, request)
|
||||
|
||||
finish_reason = ret_item["meta_info"]["finish_reason"]
|
||||
|
||||
choice_data = CompletionResponseChoice(
|
||||
@@ -316,6 +348,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import jinja2.nodes
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import LogProbs
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
LogProbs,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -205,3 +210,28 @@ def to_openai_style_logprobs(
|
||||
append_top_logprobs(output_top_logprobs)
|
||||
|
||||
return ret_logprobs
|
||||
|
||||
|
||||
def process_hidden_states_from_ret(
|
||||
ret_item: Dict[str, Any],
|
||||
request: Union[
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
],
|
||||
) -> Optional[List]:
|
||||
"""Process hidden states from a ret item in non-streaming response.
|
||||
|
||||
Args:
|
||||
ret_item: Response item containing meta_info
|
||||
request: The original request object
|
||||
|
||||
Returns:
|
||||
Processed hidden states for the last token, or None
|
||||
"""
|
||||
if not request.return_hidden_states:
|
||||
return None
|
||||
|
||||
hidden_states = ret_item["meta_info"].get("hidden_states", None)
|
||||
if hidden_states is not None:
|
||||
hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user