Implement return_hidden_states for the OpenAI API (#6137)
@@ -531,6 +531,7 @@ def v1_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     lora_paths = []
+    return_hidden_states = []

     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -577,6 +578,7 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
+        return_hidden_states.append(request.return_hidden_states)

     if len(all_requests) == 1:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
@@ -588,6 +590,7 @@ def v1_generate_request(
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         lora_paths = lora_paths[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -604,6 +607,7 @@ def v1_generate_request(
         stream=all_requests[0].stream,
         rid=request_ids,
         lora_path=lora_paths,
+        return_hidden_states=return_hidden_states,
     )

     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -669,6 +673,17 @@ def v1_generate_response(
         else:
             logprobs = None

+        hidden_states = None
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        elif (not isinstance(request, list)) and request.return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        if hidden_states is not None:
+            hidden_states = hidden_states[1:]  # trim off the prefill
+            hidden_states = (
+                hidden_states[-1] if len(hidden_states) > 0 else []
+            )  # slice out the last token
+
         finish_reason = ret_item["meta_info"]["finish_reason"]

         if to_file:
@@ -695,6 +710,7 @@ def v1_generate_response(
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )

         choices.append(choice_data)
@@ -719,6 +735,7 @@ def v1_generate_response(
                         + ret[i]["meta_info"]["completion_tokens"],
                     },
                     "system_fingerprint": None,
+                    "hidden_states": hidden_states,
                 },
             }
             responses.append(response)
@@ -763,6 +780,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             prompt_tokens = {}
             completion_tokens = {}
             cached_tokens = {}
+            hidden_states = None

             try:
                 async for content in tokenizer_manager.generate_request(
@@ -777,6 +795,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                     completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                    hidden_states = (
+                        content["meta_info"].get("hidden_states", None) or hidden_states
+                    )

                     if not stream_buffer:  # The first chunk
                         if request.echo:
@@ -882,7 +903,25 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                    total_tokens=total_prompt_tokens + total_completion_tokens,
                    prompt_tokens_details=prompt_tokens_details,
                )
-
+            if request.return_hidden_states and hidden_states:
+                hidden_states = hidden_states[1:]  # trim off the prefill
+                hidden_states = (
+                    hidden_states[-1] if len(hidden_states) > 0 else []
+                )  # slice out the last token
+                hidden_states_chunk = CompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    created=created,
+                    choices=[
+                        CompletionResponseStreamChoice(
+                            text="",
+                            index=index,
+                            hidden_states=hidden_states,
+                            finish_reason=None,
+                        )
+                    ],
+                    model=request.model,
+                )
+                yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
             final_usage_chunk = CompletionStreamResponse(
                 id=content["meta_info"]["id"],
                 created=created,
@@ -959,6 +998,7 @@ def v1_chat_generate_request(
     top_logprobs_nums = []
     modalities_list = []
     lora_paths = []
+    return_hidden_states = []

     # NOTE: with openai API, the prompt's logprobs are always not computed

@@ -1176,6 +1216,7 @@ def v1_chat_generate_request(
         image_data_list.append(image_data)
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
+        return_hidden_states.append(request.return_hidden_states)
     if len(all_requests) == 1:
         if is_multimodal:
             # processor will need text input
@@ -1194,6 +1235,7 @@ def v1_chat_generate_request(
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
         request_ids = request_ids[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if tokenizer_manager.model_config.is_multimodal:
             # processor will need text input
@@ -1220,6 +1262,7 @@ def v1_chat_generate_request(
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
+        return_hidden_states=return_hidden_states,
     )

     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -1280,6 +1323,21 @@ def v1_chat_generate_response(
         else:
             choice_logprobs = None

+        if isinstance(request, list) and request[idx].return_hidden_states:
+            include_hidden_states = True
+        elif not isinstance(request, list) and request.return_hidden_states:
+            include_hidden_states = True
+        else:
+            include_hidden_states = False
+        if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
+            hidden_states = ret_item["meta_info"]["hidden_states"]
+            hidden_states = hidden_states[1:]  # trim off the prefill
+            hidden_states = (
+                hidden_states[-1] if len(hidden_states) > 0 else []
+            )  # slice out the last token
+        else:
+            hidden_states = None
+
         finish_reason = ret_item["meta_info"]["finish_reason"]

         tool_calls = None
@@ -1344,6 +1402,7 @@ def v1_chat_generate_response(
                     "content": text if text else None,
                     "tool_calls": tool_calls,
                     "reasoning_content": reasoning_text if reasoning_text else None,
+                    "hidden_states": hidden_states,
                 },
                 "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
                 "finish_reason": finish_reason["type"] if finish_reason else None,
@@ -1369,6 +1428,7 @@ def v1_chat_generate_response(
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )

         choices.append(choice_data)
@@ -1437,6 +1497,7 @@ async def v1_chat_completions(
     if adapted_request.stream:
         parser_dict = {}
         reasoning_parser_dict = {}
+        hidden_states = None

         async def generate_stream_resp():
             tool_call_first = True
@@ -1446,12 +1507,16 @@ async def v1_chat_completions(
             prompt_tokens = {}
             completion_tokens = {}
             cached_tokens = {}
+            hidden_states = None
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
                     index = content.get("index", 0)
                     text = content["text"]
+                    hidden_states = (
+                        content["meta_info"].get("hidden_states", None) or hidden_states
+                    )

                     is_first = is_firsts.get(index, True)
                     stream_buffer = stream_buffers.get(index, "")
@@ -1573,6 +1638,7 @@ async def v1_chat_completions(
                     if (delta and len(delta) == 0) or not delta:
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
                         continue

                     if request.tool_choice != "none" and request.tools:
@@ -1661,6 +1727,7 @@ async def v1_chat_completions(

                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token

                     else:
                         # No tool calls => just treat this as normal text
@@ -1693,6 +1760,7 @@ async def v1_chat_completions(
                         yield f"data: {chunk.model_dump_json()}\n\n"
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
                     if finish_reason_type == "stop" and request.tool_choice != "none":
                         parser = FunctionCallParser(
                             tools=request.tools,
@@ -1728,6 +1796,24 @@ async def v1_chat_completions(

             else:
                 usage = None
+            if request.return_hidden_states and hidden_states:
+                hidden_states = hidden_states[1:]  # trim off the prefill
+                hidden_states = (
+                    hidden_states[-1] if len(hidden_states) > 0 else []
+                )  # slice out the last token
+                hidden_states_chunk = ChatCompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    created=created,
+                    choices=[
+                        ChatCompletionResponseStreamChoice(
+                            index=index,
+                            delta=DeltaMessage(hidden_states=hidden_states),
+                            finish_reason=finish_reason_type,
+                        )
+                    ],
+                    model=request.model,
+                )
+                yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
             final_usage_chunk = ChatCompletionStreamResponse(
                 id=content["meta_info"]["id"],
                 created=created,

@@ -16,7 +16,7 @@
 import time
 from typing import Dict, List, Optional, Union

-from pydantic import BaseModel, Field, root_validator
+from pydantic import BaseModel, Field, model_serializer, root_validator
 from typing_extensions import Literal


@@ -182,6 +182,7 @@ class CompletionRequest(BaseModel):
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None
+    return_hidden_states: Optional[bool] = False


 class CompletionResponseChoice(BaseModel):
@@ -190,6 +191,11 @@ class CompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Literal["stop", "length", "content_filter"]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])


 class CompletionResponse(BaseModel):
@@ -207,6 +213,11 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])


 class CompletionStreamResponse(BaseModel):
@@ -400,6 +411,9 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None

+    # Hidden States
+    return_hidden_states: Optional[bool] = False
+

 class ChatMessage(BaseModel):
     role: Optional[str] = None
@@ -416,6 +430,11 @@ class ChatCompletionResponseChoice(BaseModel):
         "stop", "length", "tool_calls", "content_filter", "function_call"
     ]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])


 class ChatCompletionResponse(BaseModel):
@@ -432,6 +451,11 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])


 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -484,3 +508,8 @@ class EmbeddingResponse(BaseModel):
     model: str
     object: str = "list"
     usage: Optional[UsageInfo] = None
+
+
+def exclude_if_none(obj, field_names: List[str]):
+    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
+    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
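Note (not part of the commit): a minimal, self-contained sketch of how the exclude_if_none serializer pattern above behaves, assuming pydantic v2. Fields named in the list are dropped from the serialized payload whenever they are None, so responses that do not carry hidden states keep the usual OpenAI-compatible shape. The Choice class below is a toy stand-in, not one of the real protocol models.

from typing import List, Optional

from pydantic import BaseModel, model_serializer


def exclude_if_none(obj, field_names: List[str]):
    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}


class Choice(BaseModel):
    # Toy stand-in for CompletionResponseChoice; only the serializer pattern matters.
    text: str
    hidden_states: Optional[object] = None

    @model_serializer
    def _serialize(self):
        return exclude_if_none(self, ["hidden_states"])


print(Choice(text="hi").model_dump())
# {'text': 'hi'} -- hidden_states omitted because it is None
print(Choice(text="hi", hidden_states=[0.1, 0.2]).model_dump())
# {'text': 'hi', 'hidden_states': [0.1, 0.2]}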
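Client-side usage sketch (not part of the commit). It assumes an SGLang OpenAI-compatible server is already running; the base URL, port, and model name are placeholders, and depending on the server version, returning hidden states may also need to be enabled when the server is launched. The new field is plain JSON, so a raw HTTP request is enough; per this diff, the non-streaming chat response carries the last decoded token's hidden state on the choice object.

import requests

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 8,
    "return_hidden_states": True,  # field added by this commit
}
resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload).json()

choice = resp["choices"][0]
print(choice["message"]["content"])
# When returned, hidden_states is the last token's hidden vector; otherwise the key is absent.
hidden = choice.get("hidden_states")
print(None if hidden is None else len(hidden))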
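Streaming sketch (also not part of the commit; same placeholder server and model assumptions). With stream=True, the handlers in this diff emit one extra SSE chunk near the end of the stream whose delta carries hidden_states, so a client can pick it out while printing the normal text deltas.

import json

import requests

payload = {
    "model": "default",  # placeholder
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 8,
    "stream": True,
    "return_hidden_states": True,
}
with requests.post(
    "http://localhost:30000/v1/chat/completions", json=payload, stream=True
) as r:
    for line in r.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        data = line[len(b"data: "):]
        if data == b"[DONE]":
            break
        chunk = json.loads(data)
        if not chunk.get("choices"):
            continue  # e.g. a usage-only chunk
        delta = chunk["choices"][0].get("delta") or {}
        if "hidden_states" in delta:
            print("\nlast-token hidden size:", len(delta["hidden_states"]))
        elif delta.get("content"):
            print(delta["content"], end="", flush=True)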