Implement return_hidden_states for the OpenAI API (#6137)
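The diff below extends the OpenAI-compatible test suite so completions and chat completions can ask the server to return hidden states. For orientation, here is a minimal client-side sketch of the call these tests exercise; the base URL, API key, and model name are placeholders, and only the extra_body flag and the hidden_states attribute on the choice come from the tests themselves.

import openai

# Placeholder endpoint and credentials; any sglang server exposing the
# OpenAI-compatible API will do.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

response = client.completions.create(
    model="default",  # placeholder model name
    prompt="The capital of France is",
    max_tokens=8,
    # return_hidden_states is an sglang extension, so it travels in
    # extra_body rather than as a first-class OpenAI parameter.
    extra_body=dict(return_hidden_states=True),
)

# As asserted in run_completion below, the hidden states appear on the
# choice and convert to a 1-D numpy array.
hidden_states = response.choices[0].hidden_states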
@@ -1,7 +1,9 @@
 """
 python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
 python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
 python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
 python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
 python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
 """

 import json
@@ -9,6 +11,7 @@ import re
 import time
 import unittest

+import numpy as np
 import openai

 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -43,7 +46,13 @@ class TestOpenAIServer(CustomTestCase):
         kill_process_tree(cls.process.pid)

     def run_completion(
-        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+        self,
+        echo,
+        logprobs,
+        use_list_input,
+        parallel_sample_num,
+        token_input,
+        return_hidden_states,
     ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
@@ -70,6 +79,7 @@ class TestOpenAIServer(CustomTestCase):
             echo=echo,
             logprobs=logprobs,
             n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
         )

         assert len(response.choices) == num_choices * parallel_sample_num
@@ -100,8 +110,26 @@ class TestOpenAIServer(CustomTestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

+        if return_hidden_states:
+            hidden_states = response.choices[0].hidden_states
+            assert hidden_states is not None, "hidden_states was none"
+            hidden_states = np.asarray(hidden_states)
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert not hasattr(
+                response.choices[0], "hidden_states"
+            ), "hidden_states was returned and should not have been"

     def run_completion_stream(
-        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+        self,
+        echo,
+        logprobs,
+        use_list_input,
+        parallel_sample_num,
+        token_input,
+        return_hidden_states,
     ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
@@ -130,33 +158,44 @@ class TestOpenAIServer(CustomTestCase):
             stream=True,
             stream_options={"include_usage": True},
             n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
         )

         is_firsts = {}
+        hidden_states = None
         for response in generator:
             usage = response.usage
             if usage is not None:
-                assert usage.prompt_tokens > 0
-                assert usage.completion_tokens > 0
-                assert usage.total_tokens > 0
+                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
+                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
+                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
                 continue

+            if (
+                hasattr(response.choices[0], "hidden_states")
+                and response.choices[0].hidden_states is not None
+            ):
+                hidden_states = response.choices[0].hidden_states
+                continue
+
             index = response.choices[0].index
             is_first = is_firsts.get(index, True)

             if logprobs:
-                assert response.choices[0].logprobs
-                assert isinstance(response.choices[0].logprobs.tokens[0], str)
+                assert response.choices[0].logprobs, f"no logprobs in response"
+                assert isinstance(
+                    response.choices[0].logprobs.tokens[0], str
+                ), f"{response.choices[0].logprobs.tokens[0]} is not a string"
                 if not (is_first and echo):
                     assert isinstance(
                         response.choices[0].logprobs.top_logprobs[0], dict
-                    )
+                    ), f"top_logprobs was not a dictionary"
                     ret_num_top_logprobs = len(
                         response.choices[0].logprobs.top_logprobs[0]
                     )
                     # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
                     # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
-                    assert ret_num_top_logprobs > 0
+                    assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"

             if is_first:
                 if echo:
@@ -164,15 +203,29 @@ class TestOpenAIServer(CustomTestCase):
                         prompt
                     ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                 is_firsts[index] = False
-            assert response.id
-            assert response.created
+            assert response.id, f"no id in response"
+            assert response.created, f"no created in response"

         for index in [i for i in range(parallel_sample_num * num_choices)]:
             assert not is_firsts.get(
                 index, True
             ), f"index {index} is not found in the response"

-    def run_chat_completion(self, logprobs, parallel_sample_num):
+        if return_hidden_states:
+            assert hidden_states is not None, "hidden_states is not returned"
+            try:
+                hidden_states = np.asarray(hidden_states)
+            except Exception as e:
+                raise Exception(f"Failed to convert hidden states to numpy array: {e}")
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert (
+                hidden_states is None
+            ), "hidden_states was returned and should not have been"
+
+    def run_chat_completion(self, logprobs, parallel_sample_num, return_hidden_states):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model,
@@ -187,6 +240,7 @@ class TestOpenAIServer(CustomTestCase):
             logprobs=logprobs is not None and logprobs > 0,
             top_logprobs=logprobs,
             n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
         )

         if logprobs:
@@ -210,7 +264,21 @@ class TestOpenAIServer(CustomTestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
+        if return_hidden_states:
+            hidden_states = response.choices[0].hidden_states
+            assert hidden_states is not None, "hidden_states is not returned"
+            hidden_states = np.asarray(hidden_states)
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert not hasattr(
+                response.choices[0], "hidden_states"
+            ), "hidden_states was returned and should not have been"
+
+    def run_chat_completion_stream(
+        self, logprobs, parallel_sample_num=1, return_hidden_states=False
+    ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         generator = client.chat.completions.create(
             model=self.model,
@@ -224,40 +292,55 @@ class TestOpenAIServer(CustomTestCase):
             stream=True,
             stream_options={"include_usage": True},
             n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
         )

         is_firsts = {}
+        hidden_states = None
+        top_logprob_tokens = []
         for response in generator:
             usage = response.usage
             if usage is not None:
-                assert usage.prompt_tokens > 0
-                assert usage.completion_tokens > 0
-                assert usage.total_tokens > 0
+                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
+                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
+                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
                 continue

+            if hasattr(response.choices[0].delta, "hidden_states"):
+                hidden_states = response.choices[0].delta.hidden_states
+                continue
+
             index = response.choices[0].index
             data = response.choices[0].delta

             if is_firsts.get(index, True):
-                assert data.role == "assistant"
+                assert (
+                    data.role == "assistant"
+                ), f"data.role was not 'assistant' for first chunk"
                 is_firsts[index] = False
                 continue

             if logprobs:
-                assert response.choices[0].logprobs
+                assert response.choices[0].logprobs, f"logprobs was not returned"
                 assert isinstance(
                     response.choices[0].logprobs.content[0].top_logprobs[0].token, str
-                )
+                ), f"top_logprobs token was not a string"
                 assert isinstance(
                     response.choices[0].logprobs.content[0].top_logprobs, list
-                )
+                ), f"top_logprobs was not a list"
                 ret_num_top_logprobs = len(
                     response.choices[0].logprobs.content[0].top_logprobs
                 )
                 assert (
                     ret_num_top_logprobs == logprobs
                 ), f"{ret_num_top_logprobs} vs {logprobs}"
+                top_logprob_tokens.append(
+                    response.choices[0].logprobs.content[0].top_logprobs[0].token
+                )

+            assert (
+                len(top_logprob_tokens) <= 2 or len(set(top_logprob_tokens)) > 1
+            ), "Top Logprob tokens should not consistent of the same token repeated"
             assert (
                 isinstance(data.content, str)
                 or isinstance(data.reasoning_content, str)
@@ -272,6 +355,20 @@ class TestOpenAIServer(CustomTestCase):
                 index, True
             ), f"index {index} is not found in the response"

+        if return_hidden_states:
+            assert hidden_states is not None, "hidden_states is not returned"
+            try:
+                hidden_states = np.asarray(hidden_states)
+            except Exception as e:
+                raise Exception(f"Failed to convert hidden states to numpy array: {e}")
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert (
+                hidden_states is None
+            ), "hidden_states was returned and should not have been"
+
     def _create_batch(self, mode, client):
         if mode == "completion":
             input_file_path = "complete_input.jsonl"
@@ -419,43 +516,53 @@ class TestOpenAIServer(CustomTestCase):
         assert del_response.deleted

     def test_completion(self):
-        for echo in [False, True]:
-            for logprobs in [None, 5]:
-                for use_list_input in [True, False]:
-                    for parallel_sample_num in [1, 2]:
-                        for token_input in [False, True]:
-                            self.run_completion(
-                                echo,
-                                logprobs,
-                                use_list_input,
-                                parallel_sample_num,
-                                token_input,
-                            )
+        for return_hidden_states in [False, True]:
+            for echo in [False, True]:
+                for logprobs in [None, 5]:
+                    for use_list_input in [True, False]:
+                        for parallel_sample_num in [1, 2]:
+                            for token_input in [False, True]:
+                                self.run_completion(
+                                    echo,
+                                    logprobs,
+                                    use_list_input,
+                                    parallel_sample_num,
+                                    token_input,
+                                    return_hidden_states,
+                                )

     def test_completion_stream(self):
         # parallel sampling and list input are not supported in streaming mode
-        for echo in [False, True]:
-            for logprobs in [None, 5]:
-                for use_list_input in [True, False]:
-                    for parallel_sample_num in [1, 2]:
-                        for token_input in [False, True]:
-                            self.run_completion_stream(
-                                echo,
-                                logprobs,
-                                use_list_input,
-                                parallel_sample_num,
-                                token_input,
-                            )
+        for return_hidden_states in [False, True]:
+            for echo in [False, True]:
+                for logprobs in [None, 5]:
+                    for use_list_input in [True, False]:
+                        for parallel_sample_num in [1, 2]:
+                            for token_input in [False, True]:
+                                self.run_completion_stream(
+                                    echo,
+                                    logprobs,
+                                    use_list_input,
+                                    parallel_sample_num,
+                                    token_input,
+                                    return_hidden_states,
+                                )

     def test_chat_completion(self):
-        for logprobs in [None, 5]:
-            for parallel_sample_num in [1, 2]:
-                self.run_chat_completion(logprobs, parallel_sample_num)
+        for return_hidden_states in [False, True]:
+            for logprobs in [None, 5]:
+                for parallel_sample_num in [1, 2]:
+                    self.run_chat_completion(
+                        logprobs, parallel_sample_num, return_hidden_states
+                    )

     def test_chat_completion_stream(self):
-        for logprobs in [None, 5]:
-            for parallel_sample_num in [1, 2]:
-                self.run_chat_completion_stream(logprobs, parallel_sample_num)
+        for return_hidden_states in [False, True]:
+            for logprobs in [None, 5]:
+                for parallel_sample_num in [1, 2]:
+                    self.run_chat_completion_stream(
+                        logprobs, parallel_sample_num, return_hidden_states
+                    )

     def test_batch(self):
         for mode in ["completion", "chat"]:
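For the streaming paths, the tests read hidden states from a dedicated chunk rather than from every token delta. A matching client-side sketch, again with placeholder endpoint and model names, mirroring the hasattr check in run_chat_completion_stream:

import openai

# Placeholder endpoint and credentials, as in the sketch above.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

generator = client.chat.completions.create(
    model="default",  # placeholder model name
    messages=[{"role": "user", "content": "The capital of France is"}],
    max_tokens=8,
    stream=True,
    stream_options={"include_usage": True},
    extra_body=dict(return_hidden_states=True),
)

hidden_states = None
for chunk in generator:
    # With include_usage, the final chunk carries usage and no delta content,
    # so skip it before indexing into choices.
    if chunk.usage is not None:
        continue
    delta = chunk.choices[0].delta
    # As in the test, hidden states arrive as an attribute on one delta chunk.
    if hasattr(delta, "hidden_states") and delta.hidden_states is not None:
        hidden_states = delta.hidden_states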