Implement return_hidden_states for the OpenAI API (#6137)
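This adds an opt-in `return_hidden_states` flag to the OpenAI-compatible completions and chat completions endpoints. When set, the hidden state of the last generated token is attached to each choice (for streaming, it is sent in an extra chunk just before the final usage chunk), and the field is omitted from responses entirely when the flag is not set. A minimal usage sketch mirroring the tests below; the server URL, API key, and model name are placeholders, not part of this change:

import openai

# Assumes an sglang server exposing the OpenAI-compatible API at this address.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    max_tokens=8,
    extra_body=dict(return_hidden_states=True),
)

# Present only when requested: the last generated token's hidden state.
hidden_states = response.choices[0].hidden_states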
@@ -531,6 +531,7 @@ def v1_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     lora_paths = []
+    return_hidden_states = []
 
     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -577,6 +578,7 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
+        return_hidden_states.append(request.return_hidden_states)
 
     if len(all_requests) == 1:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
@@ -588,6 +590,7 @@ def v1_generate_request(
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         lora_paths = lora_paths[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -604,6 +607,7 @@ def v1_generate_request(
         stream=all_requests[0].stream,
         rid=request_ids,
         lora_path=lora_paths,
+        return_hidden_states=return_hidden_states,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -669,6 +673,17 @@ def v1_generate_response(
         else:
             logprobs = None
 
+        hidden_states = None
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        elif (not isinstance(request, list)) and request.return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        if hidden_states is not None:
+            hidden_states = hidden_states[1:] # trim off the prefill
+            hidden_states = (
+                hidden_states[-1] if len(hidden_states) > 0 else []
+            ) # slice out the last token
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
         if to_file:
@@ -695,6 +710,7 @@ def v1_generate_response(
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )
 
         choices.append(choice_data)
@@ -719,6 +735,7 @@ def v1_generate_response(
                        + ret[i]["meta_info"]["completion_tokens"],
                    },
                    "system_fingerprint": None,
+                   "hidden_states": hidden_states,
                },
            }
            responses.append(response)
@@ -763,6 +780,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
            prompt_tokens = {}
            completion_tokens = {}
            cached_tokens = {}
+            hidden_states = None
 
            try:
                async for content in tokenizer_manager.generate_request(
@@ -777,6 +795,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                    cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                    hidden_states = (
+                        content["meta_info"].get("hidden_states", None) or hidden_states
+                    )
 
                    if not stream_buffer: # The first chunk
                        if request.echo:
@@ -882,7 +903,25 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                    total_tokens=total_prompt_tokens + total_completion_tokens,
                    prompt_tokens_details=prompt_tokens_details,
                )
+                if request.return_hidden_states and hidden_states:
+                    hidden_states = hidden_states[1:] # trim off the prefill
+                    hidden_states = (
+                        hidden_states[-1] if len(hidden_states) > 0 else []
+                    ) # slice out the last token
+                    hidden_states_chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        created=created,
+                        choices=[
+                            CompletionResponseStreamChoice(
+                                text="",
+                                index=index,
+                                hidden_states=hidden_states,
+                                finish_reason=None,
+                            )
+                        ],
+                        model=request.model,
+                    )
+                    yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
                final_usage_chunk = CompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    created=created,
@@ -959,6 +998,7 @@ def v1_chat_generate_request(
     top_logprobs_nums = []
     modalities_list = []
     lora_paths = []
+    return_hidden_states = []
 
     # NOTE: with openai API, the prompt's logprobs are always not computed
 
@@ -1176,6 +1216,7 @@ def v1_chat_generate_request(
         image_data_list.append(image_data)
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
+        return_hidden_states.append(request.return_hidden_states)
     if len(all_requests) == 1:
         if is_multimodal:
             # processor will need text input
@@ -1194,6 +1235,7 @@ def v1_chat_generate_request(
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
         request_ids = request_ids[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if tokenizer_manager.model_config.is_multimodal:
             # processor will need text input
@@ -1220,6 +1262,7 @@ def v1_chat_generate_request(
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
+        return_hidden_states=return_hidden_states,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -1280,6 +1323,21 @@ def v1_chat_generate_response(
         else:
             choice_logprobs = None
 
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            include_hidden_states = True
+        elif not isinstance(request, list) and request.return_hidden_states:
+            include_hidden_states = True
+        else:
+            include_hidden_states = False
+        if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
+            hidden_states = ret_item["meta_info"]["hidden_states"]
+            hidden_states = hidden_states[1:] # trim off the prefill
+            hidden_states = (
+                hidden_states[-1] if len(hidden_states) > 0 else []
+            ) # slice out the last token
+        else:
+            hidden_states = None
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
         tool_calls = None
@@ -1344,6 +1402,7 @@ def v1_chat_generate_response(
                    "content": text if text else None,
                    "tool_calls": tool_calls,
                    "reasoning_content": reasoning_text if reasoning_text else None,
+                    "hidden_states": hidden_states,
                },
                "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
                "finish_reason": finish_reason["type"] if finish_reason else None,
@@ -1369,6 +1428,7 @@ def v1_chat_generate_response(
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
+                hidden_states=hidden_states,
            )
 
        choices.append(choice_data)
@@ -1437,6 +1497,7 @@ async def v1_chat_completions(
     if adapted_request.stream:
         parser_dict = {}
         reasoning_parser_dict = {}
+        hidden_states = None
 
         async def generate_stream_resp():
             tool_call_first = True
@@ -1446,12 +1507,16 @@ async def v1_chat_completions(
             prompt_tokens = {}
             completion_tokens = {}
             cached_tokens = {}
+            hidden_states = None
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
                     index = content.get("index", 0)
                     text = content["text"]
+                    hidden_states = (
+                        content["meta_info"].get("hidden_states", None) or hidden_states
+                    )
 
                     is_first = is_firsts.get(index, True)
                     stream_buffer = stream_buffers.get(index, "")
@@ -1573,6 +1638,7 @@ async def v1_chat_completions(
                    if (delta and len(delta) == 0) or not delta:
                        stream_buffers[index] = new_stream_buffer
                        is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
                        continue
 
                    if request.tool_choice != "none" and request.tools:
@@ -1661,6 +1727,7 @@ async def v1_chat_completions(
 
                        stream_buffers[index] = new_stream_buffer
                        is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
 
                    else:
                        # No tool calls => just treat this as normal text
@@ -1693,6 +1760,7 @@ async def v1_chat_completions(
                        yield f"data: {chunk.model_dump_json()}\n\n"
                        stream_buffers[index] = new_stream_buffer
                        is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
            if finish_reason_type == "stop" and request.tool_choice != "none":
                parser = FunctionCallParser(
                    tools=request.tools,
@@ -1728,6 +1796,24 @@ async def v1_chat_completions(
 
            else:
                usage = None
+            if request.return_hidden_states and hidden_states:
+                hidden_states = hidden_states[1:] # trim off the prefill
+                hidden_states = (
+                    hidden_states[-1] if len(hidden_states) > 0 else []
+                ) # slice out the last token
+                hidden_states_chunk = ChatCompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    created=created,
+                    choices=[
+                        ChatCompletionResponseStreamChoice(
+                            index=index,
+                            delta=DeltaMessage(hidden_states=hidden_states),
+                            finish_reason=finish_reason_type,
+                        )
+                    ],
+                    model=request.model,
+                )
+                yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
            final_usage_chunk = ChatCompletionStreamResponse(
                id=content["meta_info"]["id"],
                created=created,
@@ -16,7 +16,7 @@
 import time
 from typing import Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field, root_validator
+from pydantic import BaseModel, Field, model_serializer, root_validator
 from typing_extensions import Literal
 
 
@@ -182,6 +182,7 @@ class CompletionRequest(BaseModel):
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None
+    return_hidden_states: Optional[bool] = False
 
 
 class CompletionResponseChoice(BaseModel):
@@ -190,6 +191,11 @@ class CompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Literal["stop", "length", "content_filter"]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class CompletionResponse(BaseModel):
@@ -207,6 +213,11 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class CompletionStreamResponse(BaseModel):
@@ -400,6 +411,9 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
 
+    # Hidden States
+    return_hidden_states: Optional[bool] = False
+
 
 class ChatMessage(BaseModel):
     role: Optional[str] = None
@@ -416,6 +430,11 @@ class ChatCompletionResponseChoice(BaseModel):
         "stop", "length", "tool_calls", "content_filter", "function_call"
     ]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class ChatCompletionResponse(BaseModel):
@@ -432,6 +451,11 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -484,3 +508,8 @@ class EmbeddingResponse(BaseModel):
     model: str
     object: str = "list"
     usage: Optional[UsageInfo] = None
+
+
+def exclude_if_none(obj, field_names: List[str]):
+    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
+    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
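The protocol additions above rely on pydantic v2's `@model_serializer` together with the new `exclude_if_none` helper, so `hidden_states` only appears in serialized responses when it was actually populated. A small self-contained sketch of that pattern (the `Choice` model here is illustrative, not one of the real protocol classes):

from typing import List, Optional

from pydantic import BaseModel, model_serializer


def exclude_if_none(obj, field_names: List[str]):
    # Drop the listed fields from the serialized dict when their value is None.
    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}


class Choice(BaseModel):
    text: str
    hidden_states: Optional[object] = None

    @model_serializer
    def _serialize(self):
        return exclude_if_none(self, ["hidden_states"])


print(Choice(text="hi").model_dump())                        # {'text': 'hi'}
print(Choice(text="hi", hidden_states=[0.1, 0.2]).model_dump())  # field included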
@@ -1,7 +1,9 @@
 """
 python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
 python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
+python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
+python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
+python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
 """
 
 import json
@@ -9,6 +11,7 @@ import re
 import time
 import unittest
 
+import numpy as np
 import openai
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -43,7 +46,13 @@ class TestOpenAIServer(CustomTestCase):
         kill_process_tree(cls.process.pid)
 
     def run_completion(
-        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+        self,
+        echo,
+        logprobs,
+        use_list_input,
+        parallel_sample_num,
+        token_input,
+        return_hidden_states,
     ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
@@ -70,6 +79,7 @@ class TestOpenAIServer(CustomTestCase):
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
        )
 
        assert len(response.choices) == num_choices * parallel_sample_num
@@ -100,8 +110,26 @@ class TestOpenAIServer(CustomTestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0
 
+        if return_hidden_states:
+            hidden_states = response.choices[0].hidden_states
+            assert hidden_states is not None, "hidden_states was none"
+            hidden_states = np.asarray(hidden_states)
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert not hasattr(
+                response.choices[0], "hidden_states"
+            ), "hidden_states was returned and should not have been"
+
     def run_completion_stream(
-        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+        self,
+        echo,
+        logprobs,
+        use_list_input,
+        parallel_sample_num,
+        token_input,
+        return_hidden_states,
     ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
@@ -130,33 +158,44 @@ class TestOpenAIServer(CustomTestCase):
             stream=True,
             stream_options={"include_usage": True},
             n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
         )
 
         is_firsts = {}
+        hidden_states = None
         for response in generator:
             usage = response.usage
             if usage is not None:
-                assert usage.prompt_tokens > 0
-                assert usage.completion_tokens > 0
-                assert usage.total_tokens > 0
+                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
+                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
+                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
+                continue
+
+            if (
+                hasattr(response.choices[0], "hidden_states")
+                and response.choices[0].hidden_states is not None
+            ):
+                hidden_states = response.choices[0].hidden_states
                 continue
 
             index = response.choices[0].index
             is_first = is_firsts.get(index, True)
 
             if logprobs:
-                assert response.choices[0].logprobs
-                assert isinstance(response.choices[0].logprobs.tokens[0], str)
+                assert response.choices[0].logprobs, f"no logprobs in response"
+                assert isinstance(
+                    response.choices[0].logprobs.tokens[0], str
+                ), f"{response.choices[0].logprobs.tokens[0]} is not a string"
                 if not (is_first and echo):
                     assert isinstance(
                         response.choices[0].logprobs.top_logprobs[0], dict
-                    )
+                    ), f"top_logprobs was not a dictionary"
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
                    # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
-                    assert ret_num_top_logprobs > 0
+                    assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"
 
            if is_first:
                if echo:
@@ -164,15 +203,29 @@ class TestOpenAIServer(CustomTestCase):
                        prompt
                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                is_firsts[index] = False
-            assert response.id
-            assert response.created
+            assert response.id, f"no id in response"
+            assert response.created, f"no created in response"
 
        for index in [i for i in range(parallel_sample_num * num_choices)]:
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"
 
-    def run_chat_completion(self, logprobs, parallel_sample_num):
+        if return_hidden_states:
+            assert hidden_states is not None, "hidden_states is not returned"
+            try:
+                hidden_states = np.asarray(hidden_states)
+            except Exception as e:
+                raise Exception(f"Failed to convert hidden states to numpy array: {e}")
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert (
+                hidden_states is None
+            ), "hidden_states was returned and should not have been"
+
+    def run_chat_completion(self, logprobs, parallel_sample_num, return_hidden_states):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model,
@@ -187,6 +240,7 @@ class TestOpenAIServer(CustomTestCase):
             logprobs=logprobs is not None and logprobs > 0,
             top_logprobs=logprobs,
             n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
         )
 
         if logprobs:
@@ -210,7 +264,21 @@ class TestOpenAIServer(CustomTestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0
 
-    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
+        if return_hidden_states:
+            hidden_states = response.choices[0].hidden_states
+            assert hidden_states is not None, "hidden_states is not returned"
+            hidden_states = np.asarray(hidden_states)
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert not hasattr(
+                response.choices[0], "hidden_states"
+            ), "hidden_states was returned and should not have been"
+
+    def run_chat_completion_stream(
+        self, logprobs, parallel_sample_num=1, return_hidden_states=False
+    ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         generator = client.chat.completions.create(
             model=self.model,
@@ -224,40 +292,55 @@ class TestOpenAIServer(CustomTestCase):
             stream=True,
             stream_options={"include_usage": True},
             n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
         )
 
         is_firsts = {}
+        hidden_states = None
+        top_logprob_tokens = []
        for response in generator:
            usage = response.usage
            if usage is not None:
-                assert usage.prompt_tokens > 0
-                assert usage.completion_tokens > 0
-                assert usage.total_tokens > 0
+                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
+                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
+                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
+                continue
+
+            if hasattr(response.choices[0].delta, "hidden_states"):
+                hidden_states = response.choices[0].delta.hidden_states
                continue
 
            index = response.choices[0].index
            data = response.choices[0].delta
 
            if is_firsts.get(index, True):
-                assert data.role == "assistant"
+                assert (
+                    data.role == "assistant"
+                ), f"data.role was not 'assistant' for first chunk"
                is_firsts[index] = False
                continue
 
            if logprobs:
-                assert response.choices[0].logprobs
+                assert response.choices[0].logprobs, f"logprobs was not returned"
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
-                )
+                ), f"top_logprobs token was not a string"
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs, list
-                )
+                ), f"top_logprobs was not a list"
                ret_num_top_logprobs = len(
                    response.choices[0].logprobs.content[0].top_logprobs
                )
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"
+                top_logprob_tokens.append(
+                    response.choices[0].logprobs.content[0].top_logprobs[0].token
+                )
+
+            assert (
+                len(top_logprob_tokens) <= 2 or len(set(top_logprob_tokens)) > 1
+            ), "Top Logprob tokens should not consistent of the same token repeated"
            assert (
                isinstance(data.content, str)
                or isinstance(data.reasoning_content, str)
@@ -272,6 +355,20 @@ class TestOpenAIServer(CustomTestCase):
                index, True
            ), f"index {index} is not found in the response"
 
+        if return_hidden_states:
+            assert hidden_states is not None, "hidden_states is not returned"
+            try:
+                hidden_states = np.asarray(hidden_states)
+            except Exception as e:
+                raise Exception(f"Failed to convert hidden states to numpy array: {e}")
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert (
+                hidden_states is None
+            ), "hidden_states was returned and should not have been"
+
    def _create_batch(self, mode, client):
        if mode == "completion":
            input_file_path = "complete_input.jsonl"
@@ -419,43 +516,53 @@ class TestOpenAIServer(CustomTestCase):
        assert del_response.deleted
 
    def test_completion(self):
-        for echo in [False, True]:
-            for logprobs in [None, 5]:
-                for use_list_input in [True, False]:
-                    for parallel_sample_num in [1, 2]:
-                        for token_input in [False, True]:
-                            self.run_completion(
-                                echo,
-                                logprobs,
-                                use_list_input,
-                                parallel_sample_num,
-                                token_input,
-                            )
+        for return_hidden_states in [False, True]:
+            for echo in [False, True]:
+                for logprobs in [None, 5]:
+                    for use_list_input in [True, False]:
+                        for parallel_sample_num in [1, 2]:
+                            for token_input in [False, True]:
+                                self.run_completion(
+                                    echo,
+                                    logprobs,
+                                    use_list_input,
+                                    parallel_sample_num,
+                                    token_input,
+                                    return_hidden_states,
+                                )
 
    def test_completion_stream(self):
        # parallel sampling and list input are not supported in streaming mode
-        for echo in [False, True]:
-            for logprobs in [None, 5]:
-                for use_list_input in [True, False]:
-                    for parallel_sample_num in [1, 2]:
-                        for token_input in [False, True]:
-                            self.run_completion_stream(
-                                echo,
-                                logprobs,
-                                use_list_input,
-                                parallel_sample_num,
-                                token_input,
-                            )
+        for return_hidden_states in [False, True]:
+            for echo in [False, True]:
+                for logprobs in [None, 5]:
+                    for use_list_input in [True, False]:
+                        for parallel_sample_num in [1, 2]:
+                            for token_input in [False, True]:
+                                self.run_completion_stream(
+                                    echo,
+                                    logprobs,
+                                    use_list_input,
+                                    parallel_sample_num,
+                                    token_input,
+                                    return_hidden_states,
+                                )
 
    def test_chat_completion(self):
-        for logprobs in [None, 5]:
-            for parallel_sample_num in [1, 2]:
-                self.run_chat_completion(logprobs, parallel_sample_num)
+        for return_hidden_states in [False, True]:
+            for logprobs in [None, 5]:
+                for parallel_sample_num in [1, 2]:
+                    self.run_chat_completion(
+                        logprobs, parallel_sample_num, return_hidden_states
+                    )
 
    def test_chat_completion_stream(self):
-        for logprobs in [None, 5]:
-            for parallel_sample_num in [1, 2]:
-                self.run_chat_completion_stream(logprobs, parallel_sample_num)
+        for return_hidden_states in [False, True]:
+            for logprobs in [None, 5]:
+                for parallel_sample_num in [1, 2]:
+                    self.run_chat_completion_stream(
+                        logprobs, parallel_sample_num, return_hidden_states
+                    )
 
    def test_batch(self):
        for mode in ["completion", "chat"]:
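On the streaming side, the tests above show how the hidden state arrives: a dedicated chunk whose delta carries `hidden_states` is emitted just before the final usage chunk. A consumption sketch under the same assumptions as the first example (placeholder URL and model name):

import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    stream=True,
    stream_options={"include_usage": True},
    extra_body=dict(return_hidden_states=True),
)

hidden_states = None
for chunk in stream:
    if chunk.usage is not None or not chunk.choices:
        continue  # usage-only chunk at the end of the stream
    delta = chunk.choices[0].delta
    if hasattr(delta, "hidden_states") and delta.hidden_states is not None:
        hidden_states = delta.hidden_states  # last-token hidden state
    elif delta.content:
        print(delta.content, end="", flush=True)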