Implement return_hidden_states for the OpenAI API (#6137)

kyle-pena-kuzco
2025-05-19 01:30:25 -04:00
committed by GitHub
parent 31c9569bb8
commit 4f39bcf7ab
3 changed files with 275 additions and 53 deletions
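For orientation, here is what the new parameter looks like from the client side, distilled from the tests below: return_hidden_states travels through the OpenAI SDK's extra_body escape hatch, and the final hidden state comes back on each choice as a 1-D vector. A minimal non-streaming sketch; the endpoint, API key, and model name are placeholders rather than values taken from this diff:

import numpy as np
import openai

# Placeholder endpoint and key for a locally served SGLang instance.
client = openai.Client(api_key="sk-none", base_url="http://127.0.0.1:30000/v1")

response = client.completions.create(
    model="your-served-model",  # placeholder, not taken from this diff
    prompt="The capital of France is",
    max_tokens=16,
    # return_hidden_states is an SGLang extension, so it rides in extra_body.
    extra_body=dict(return_hidden_states=True),
)

# As the tests below assert: each choice carries a hidden_states field
# holding the final hidden state as a 1-D vector.
hidden_states = np.asarray(response.choices[0].hidden_states)
assert len(hidden_states.shape) == 1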

@@ -1,7 +1,9 @@
"""
python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
"""
import json
@@ -9,6 +11,7 @@ import re
import time
import unittest
import numpy as np
import openai
from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -43,7 +46,13 @@ class TestOpenAIServer(CustomTestCase):
kill_process_tree(cls.process.pid)
def run_completion(
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
self,
echo,
logprobs,
use_list_input,
parallel_sample_num,
token_input,
return_hidden_states,
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
prompt = "The capital of France is"
@@ -70,6 +79,7 @@ class TestOpenAIServer(CustomTestCase):
echo=echo,
logprobs=logprobs,
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
assert len(response.choices) == num_choices * parallel_sample_num
@@ -100,8 +110,26 @@ class TestOpenAIServer(CustomTestCase):
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
if return_hidden_states:
hidden_states = response.choices[0].hidden_states
assert hidden_states is not None, "hidden_states was None"
hidden_states = np.asarray(hidden_states)
assert (
len(hidden_states.shape) == 1
), f"hidden_states shape is not correct, was {hidden_states.shape}"
else:
assert not hasattr(
response.choices[0], "hidden_states"
), "hidden_states was returned and should not have been"
def run_completion_stream(
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
self,
echo,
logprobs,
use_list_input,
parallel_sample_num,
token_input,
return_hidden_states,
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
prompt = "The capital of France is"
@@ -130,33 +158,44 @@ class TestOpenAIServer(CustomTestCase):
stream=True,
stream_options={"include_usage": True},
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
is_firsts = {}
hidden_states = None
for response in generator:
usage = response.usage
if usage is not None:
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens > 0
assert usage.prompt_tokens > 0, "usage.prompt_tokens was zero"
assert usage.completion_tokens > 0, "usage.completion_tokens was zero"
assert usage.total_tokens > 0, "usage.total_tokens was zero"
continue
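# Hidden states arrive in a dedicated chunk attached to the choice rather
# than on the per-token chunks, so capture them here and skip the rest.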
if (
hasattr(response.choices[0], "hidden_states")
and response.choices[0].hidden_states is not None
):
hidden_states = response.choices[0].hidden_states
continue
index = response.choices[0].index
is_first = is_firsts.get(index, True)
if logprobs:
assert response.choices[0].logprobs
assert isinstance(response.choices[0].logprobs.tokens[0], str)
assert response.choices[0].logprobs, "no logprobs in response"
assert isinstance(
response.choices[0].logprobs.tokens[0], str
), f"{response.choices[0].logprobs.tokens[0]} is not a string"
if not (is_first and echo):
assert isinstance(
response.choices[0].logprobs.top_logprobs[0], dict
)
), f"top_logprobs was not a dictionary"
ret_num_top_logprobs = len(
response.choices[0].logprobs.top_logprobs[0]
)
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
assert ret_num_top_logprobs > 0
assert ret_num_top_logprobs > 0, "ret_num_top_logprobs was 0"
if is_first:
if echo:
@@ -164,15 +203,29 @@ class TestOpenAIServer(CustomTestCase):
prompt
), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
is_firsts[index] = False
assert response.id
assert response.created
assert response.id, "no id in response"
assert response.created, "no created in response"
for index in range(parallel_sample_num * num_choices):
assert not is_firsts.get(
index, True
), f"index {index} is not found in the response"
def run_chat_completion(self, logprobs, parallel_sample_num):
if return_hidden_states:
assert hidden_states is not None, "hidden_states was not returned"
try:
hidden_states = np.asarray(hidden_states)
except Exception as e:
raise AssertionError(f"Failed to convert hidden states to numpy array: {e}") from e
assert (
len(hidden_states.shape) == 1
), f"hidden_states shape is not correct, was {hidden_states.shape}"
else:
assert (
hidden_states is None
), "hidden_states was returned and should not have been"
def run_chat_completion(self, logprobs, parallel_sample_num, return_hidden_states):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
@@ -187,6 +240,7 @@ class TestOpenAIServer(CustomTestCase):
logprobs=logprobs is not None and logprobs > 0,
top_logprobs=logprobs,
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
if logprobs:
@@ -210,7 +264,21 @@ class TestOpenAIServer(CustomTestCase):
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
if return_hidden_states:
hidden_states = response.choices[0].hidden_states
assert hidden_states is not None, "hidden_states was not returned"
hidden_states = np.asarray(hidden_states)
assert (
len(hidden_states.shape) == 1
), f"hidden_states shape is not correct, was {hidden_states.shape}"
else:
assert not hasattr(
response.choices[0], "hidden_states"
), "hidden_states was returned and should not have been"
def run_chat_completion_stream(
self, logprobs, parallel_sample_num=1, return_hidden_states=False
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
generator = client.chat.completions.create(
model=self.model,
@@ -224,40 +292,55 @@ class TestOpenAIServer(CustomTestCase):
stream=True,
stream_options={"include_usage": True},
n=parallel_sample_num,
extra_body=dict(return_hidden_states=return_hidden_states),
)
is_firsts = {}
hidden_states = None
top_logprob_tokens = []
for response in generator:
usage = response.usage
if usage is not None:
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens > 0
assert usage.prompt_tokens > 0, "usage.prompt_tokens was zero"
assert usage.completion_tokens > 0, "usage.completion_tokens was zero"
assert usage.total_tokens > 0, "usage.total_tokens was zero"
continue
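# For chat completions, hidden states ride on the delta of their own chunk;
# capture them and skip the regular per-chunk checks.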
if hasattr(response.choices[0].delta, "hidden_states"):
hidden_states = response.choices[0].delta.hidden_states
continue
index = response.choices[0].index
data = response.choices[0].delta
if is_firsts.get(index, True):
assert data.role == "assistant"
assert (
data.role == "assistant"
), f"data.role was not 'assistant' for first chunk"
is_firsts[index] = False
continue
if logprobs:
assert response.choices[0].logprobs
assert response.choices[0].logprobs, "logprobs was not returned"
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
)
), f"top_logprobs token was not a string"
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs, list
)
), f"top_logprobs was not a list"
ret_num_top_logprobs = len(
response.choices[0].logprobs.content[0].top_logprobs
)
assert (
ret_num_top_logprobs == logprobs
), f"{ret_num_top_logprobs} vs {logprobs}"
top_logprob_tokens.append(
response.choices[0].logprobs.content[0].top_logprobs[0].token
)
assert (
len(top_logprob_tokens) <= 2 or len(set(top_logprob_tokens)) > 1
), "Top Logprob tokens should not consistent of the same token repeated"
assert (
isinstance(data.content, str)
or isinstance(data.reasoning_content, str)
@@ -272,6 +355,20 @@ class TestOpenAIServer(CustomTestCase):
index, True
), f"index {index} is not found in the response"
if return_hidden_states:
assert hidden_states is not None, "hidden_states was not returned"
try:
hidden_states = np.asarray(hidden_states)
except Exception as e:
raise AssertionError(f"Failed to convert hidden states to numpy array: {e}") from e
assert (
len(hidden_states.shape) == 1
), f"hidden_states shape is not correct, was {hidden_states.shape}"
else:
assert (
hidden_states is None
), "hidden_states was returned and should not have been"
def _create_batch(self, mode, client):
if mode == "completion":
input_file_path = "complete_input.jsonl"
@@ -419,43 +516,53 @@ class TestOpenAIServer(CustomTestCase):
assert del_response.deleted
def test_completion(self):
for echo in [False, True]:
for logprobs in [None, 5]:
for use_list_input in [True, False]:
for parallel_sample_num in [1, 2]:
for token_input in [False, True]:
self.run_completion(
echo,
logprobs,
use_list_input,
parallel_sample_num,
token_input,
)
for return_hidden_states in [False, True]:
for echo in [False, True]:
for logprobs in [None, 5]:
for use_list_input in [True, False]:
for parallel_sample_num in [1, 2]:
for token_input in [False, True]:
self.run_completion(
echo,
logprobs,
use_list_input,
parallel_sample_num,
token_input,
return_hidden_states,
)
def test_completion_stream(self):
# parallel sampling and list input are not supported in streaming mode
for echo in [False, True]:
for logprobs in [None, 5]:
for use_list_input in [True, False]:
for parallel_sample_num in [1, 2]:
for token_input in [False, True]:
self.run_completion_stream(
echo,
logprobs,
use_list_input,
parallel_sample_num,
token_input,
)
for return_hidden_states in [False, True]:
for echo in [False, True]:
for logprobs in [None, 5]:
for use_list_input in [True, False]:
for parallel_sample_num in [1, 2]:
for token_input in [False, True]:
self.run_completion_stream(
echo,
logprobs,
use_list_input,
parallel_sample_num,
token_input,
return_hidden_states,
)
def test_chat_completion(self):
for logprobs in [None, 5]:
for parallel_sample_num in [1, 2]:
self.run_chat_completion(logprobs, parallel_sample_num)
for return_hidden_states in [False, True]:
for logprobs in [None, 5]:
for parallel_sample_num in [1, 2]:
self.run_chat_completion(
logprobs, parallel_sample_num, return_hidden_states
)
def test_chat_completion_stream(self):
for logprobs in [None, 5]:
for parallel_sample_num in [1, 2]:
self.run_chat_completion_stream(logprobs, parallel_sample_num)
for return_hidden_states in [False, True]:
for logprobs in [None, 5]:
for parallel_sample_num in [1, 2]:
self.run_chat_completion_stream(
logprobs, parallel_sample_num, return_hidden_states
)
def test_batch(self):
for mode in ["completion", "chat"]:
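To round out the picture, a matching sketch for the streaming side, mirroring the chunk-handling pattern the streaming tests above exercise: hidden states arrive on the delta of a dedicated chunk rather than with the per-token content. Endpoint, API key, and model name are again placeholders:

import numpy as np
import openai

client = openai.Client(api_key="sk-none", base_url="http://127.0.0.1:30000/v1")

generator = client.chat.completions.create(
    model="your-served-model",  # placeholder, not taken from this diff
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    stream=True,
    extra_body=dict(return_hidden_states=True),
)

hidden_states = None
for chunk in generator:
    if not chunk.choices:
        continue
    delta = chunk.choices[0].delta
    # As in the tests: hidden states show up on the delta of their own chunk.
    if getattr(delta, "hidden_states", None) is not None:
        hidden_states = np.asarray(delta.hidden_states)
    elif delta.content:
        print(delta.content, end="")

assert hidden_states is not None and len(hidden_states.shape) == 1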