Open AI API hidden states (#6716)
This commit is contained in:
@@ -59,6 +59,7 @@ suites = {
|
||||
TestFile("test_openai_adapter.py", 1),
|
||||
TestFile("test_openai_function_calling.py", 60),
|
||||
TestFile("test_openai_server.py", 149),
|
||||
TestFile("test_openai_server_hidden_states.py", 240),
|
||||
TestFile("test_penalty.py", 41),
|
||||
TestFile("test_page_size.py", 60),
|
||||
TestFile("test_pytorch_sampling_backend.py", 66),
|
||||
|
||||
@@ -23,6 +23,7 @@ class TestHiddenState(CustomTestCase):
|
||||
model_path=model_path,
|
||||
random_seed=42,
|
||||
skip_tokenizer_init=True,
|
||||
enable_return_hidden_states=True,
|
||||
)
|
||||
outputs = engine.generate(
|
||||
input_ids=input_ids,
|
||||
@@ -96,6 +97,7 @@ class TestHiddenState(CustomTestCase):
|
||||
model_path=model_path,
|
||||
random_seed=42,
|
||||
skip_tokenizer_init=True,
|
||||
enable_return_hidden_states=True,
|
||||
)
|
||||
outputs_completion_first_round = engine.generate(
|
||||
input_ids=input_ids,
|
||||
|
||||
@@ -381,12 +381,14 @@ class TestGenerateReqInputNormalization(CustomTestCase):
|
||||
logprob_start_len=[10, 5],
|
||||
top_logprobs_num=[5, 3],
|
||||
token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
|
||||
return_hidden_states=[False, False, True],
|
||||
)
|
||||
req.normalize_batch_and_arguments()
|
||||
self.assertEqual(req.return_logprob, [True, False])
|
||||
self.assertEqual(req.logprob_start_len, [10, 5])
|
||||
self.assertEqual(req.top_logprobs_num, [5, 3])
|
||||
self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [4, 5, 6]])
|
||||
self.assertEqual(req.return_hidden_states, [False, False, True])
|
||||
|
||||
def test_custom_logit_processor_normalization(self):
|
||||
"""Test normalization of custom_logit_processor."""
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
|
||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
|
||||
|
||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
|
||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
|
||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -9,6 +11,7 @@ import re
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import requests
|
||||
|
||||
@@ -137,27 +140,29 @@ class TestOpenAIServer(CustomTestCase):
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
assert usage.prompt_tokens > 0
|
||||
assert usage.completion_tokens > 0
|
||||
assert usage.total_tokens > 0
|
||||
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
|
||||
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
|
||||
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
is_first = is_firsts.get(index, True)
|
||||
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs
|
||||
assert isinstance(response.choices[0].logprobs.tokens[0], str)
|
||||
assert response.choices[0].logprobs, f"no logprobs in response"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.tokens[0], str
|
||||
), f"{response.choices[0].logprobs.tokens[0]} is not a string"
|
||||
if not (is_first and echo):
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.top_logprobs[0], dict
|
||||
)
|
||||
), f"top_logprobs was not a dictionary"
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.top_logprobs[0]
|
||||
)
|
||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
|
||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
assert ret_num_top_logprobs > 0
|
||||
assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"
|
||||
|
||||
if is_first:
|
||||
if echo:
|
||||
@@ -165,8 +170,8 @@ class TestOpenAIServer(CustomTestCase):
|
||||
prompt
|
||||
), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
|
||||
is_firsts[index] = False
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert response.id, f"no id in response"
|
||||
assert response.created, f"no created in response"
|
||||
|
||||
for index in [i for i in range(parallel_sample_num * num_choices)]:
|
||||
assert not is_firsts.get(
|
||||
@@ -231,27 +236,29 @@ class TestOpenAIServer(CustomTestCase):
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
assert usage.prompt_tokens > 0
|
||||
assert usage.completion_tokens > 0
|
||||
assert usage.total_tokens > 0
|
||||
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
|
||||
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
|
||||
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
data = response.choices[0].delta
|
||||
|
||||
if is_firsts.get(index, True):
|
||||
assert data.role == "assistant"
|
||||
assert (
|
||||
data.role == "assistant"
|
||||
), f"data.role was not 'assistant' for first chunk"
|
||||
is_firsts[index] = False
|
||||
continue
|
||||
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs
|
||||
assert response.choices[0].logprobs, f"logprobs was not returned"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
)
|
||||
), f"top_logprobs token was not a string"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs, list
|
||||
)
|
||||
), f"top_logprobs was not a list"
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.content[0].top_logprobs
|
||||
)
|
||||
|
||||
356
test/srt/test_openai_server_hidden_states.py
Normal file
356
test/srt/test_openai_server_hidden_states.py
Normal file
@@ -0,0 +1,356 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import unittest
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import torch
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class BaseTestOpenAIServerWithHiddenStates(ABC):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.return_hidden_states = [False, True]
|
||||
cls.use_list_input = [True, False]
|
||||
cls.parallel_sample_nums = [1, 2]
|
||||
|
||||
def test_completion(self):
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for use_list_input in self.use_list_input:
|
||||
for parallel_sample_num in self.parallel_sample_nums:
|
||||
self.run_completion(
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
return_hidden_states,
|
||||
)
|
||||
|
||||
def test_completion_stream(self):
|
||||
# parallel sampling and list input are not supported in streaming mode
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for use_list_input in self.use_list_input:
|
||||
for parallel_sample_num in self.parallel_sample_nums:
|
||||
self.run_completion_stream(
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
return_hidden_states,
|
||||
)
|
||||
|
||||
def test_chat_completion(self):
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for (
|
||||
parallel_sample_num
|
||||
) in (
|
||||
self.parallel_sample_nums
|
||||
): # parallel sample num 2 breaks in the adapter with a 400 for EAGLE
|
||||
self.run_chat_completion(parallel_sample_num, return_hidden_states)
|
||||
|
||||
def test_chat_completion_stream(self):
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for (
|
||||
parallel_sample_num
|
||||
) in (
|
||||
self.parallel_sample_nums
|
||||
): # parallel sample num > 1 breaks in the adapter with a 400 for EAGLE
|
||||
self.run_chat_completion_stream(
|
||||
parallel_sample_num, return_hidden_states
|
||||
)
|
||||
|
||||
def run_completion(
|
||||
self,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
return_hidden_states,
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
prompt_input = prompt
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
response = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
n=parallel_sample_num,
|
||||
extra_body=dict(return_hidden_states=return_hidden_states),
|
||||
)
|
||||
|
||||
for choice in response.choices:
|
||||
assert hasattr(choice, "hidden_states") == return_hidden_states
|
||||
if return_hidden_states:
|
||||
assert choice.hidden_states is not None, "hidden_states was None"
|
||||
|
||||
def run_completion_stream(
|
||||
self,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
return_hidden_states,
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
prompt_input = prompt
|
||||
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
num_prompt_tokens *= 2
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
generator = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
extra_body=dict(return_hidden_states=return_hidden_states),
|
||||
)
|
||||
|
||||
hidden_states_list = []
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
for choice in response.choices:
|
||||
if hasattr(choice, "hidden_states"):
|
||||
assert return_hidden_states
|
||||
assert choice.hidden_states is not None
|
||||
hidden_states_list.append(choice.hidden_states)
|
||||
|
||||
if return_hidden_states:
|
||||
assert (
|
||||
len(hidden_states_list) == parallel_sample_num * num_choices
|
||||
), f"Expected {parallel_sample_num * num_choices} hidden states, got {len(hidden_states_list)}"
|
||||
else:
|
||||
assert (
|
||||
hidden_states_list == []
|
||||
), "hidden_states were returned and should not have been"
|
||||
|
||||
def run_chat_completion(self, parallel_sample_num, return_hidden_states):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France? Answer in a few words.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
n=parallel_sample_num,
|
||||
extra_body=dict(return_hidden_states=return_hidden_states),
|
||||
)
|
||||
|
||||
for choice in response.choices:
|
||||
assert hasattr(choice, "hidden_states") == return_hidden_states
|
||||
if return_hidden_states:
|
||||
assert choice.hidden_states is not None, "hidden_states was None"
|
||||
|
||||
def run_chat_completion_stream(
|
||||
self, parallel_sample_num=1, return_hidden_states=False
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
generator = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
],
|
||||
temperature=0,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
extra_body=dict(return_hidden_states=return_hidden_states),
|
||||
)
|
||||
|
||||
is_firsts = {}
|
||||
hidden_states_list = []
|
||||
|
||||
for response in generator:
|
||||
for choice in response.choices:
|
||||
if hasattr(choice.delta, "hidden_states"):
|
||||
assert return_hidden_states
|
||||
assert choice.delta.hidden_states is not None
|
||||
hidden_states_list.append(choice.delta.hidden_states)
|
||||
|
||||
if return_hidden_states:
|
||||
assert (
|
||||
len(hidden_states_list) == parallel_sample_num
|
||||
), f"Expected {parallel_sample_num} hidden states, got {len(hidden_states_list)}"
|
||||
else:
|
||||
assert (
|
||||
hidden_states_list == []
|
||||
), "hidden_states were returned and should not have been"
|
||||
|
||||
|
||||
class TestOpenAIServerWithHiddenStatesEnabled(
|
||||
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
|
||||
):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=["--enable-return-hidden-states"],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
cls.return_hidden_states = [False, True]
|
||||
cls.use_list_input = [True, False]
|
||||
cls.parallel_sample_nums = [1, 2]
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
class TestOpenAIServerWithHiddenStatesEnabledAndCUDAGraphDisabled(
|
||||
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
|
||||
):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=["--enable-return-hidden-states", "--disable-cuda-graph"],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
cls.return_hidden_states = [False, True]
|
||||
cls.use_list_input = [True, False]
|
||||
cls.parallel_sample_nums = [1]
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
class TestOpenAIServerWithEAGLEAndHiddenStatesEnabled(
|
||||
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
|
||||
):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.speculative_draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
|
||||
cls.speculative_algorithm = "EAGLE"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--speculative-algorithm",
|
||||
"EAGLE",
|
||||
"--speculative-draft-model-path",
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
"--speculative-num-steps",
|
||||
5,
|
||||
"--speculative-eagle-topk",
|
||||
8,
|
||||
"--speculative-num-draft-tokens",
|
||||
64,
|
||||
"--mem-fraction-static",
|
||||
0.7,
|
||||
"--chunked-prefill-size",
|
||||
128,
|
||||
"--max-running-requests",
|
||||
8,
|
||||
"--enable-return-hidden-states",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
|
||||
cls.return_hidden_states = [False, True]
|
||||
cls.use_list_input = [True, False]
|
||||
cls.parallel_sample_nums = [1]
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
class TestOpenAIServerWithEAGLE3AndHiddenStatesEnabled(
|
||||
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
|
||||
):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.speculative_algorithm = "EAGLE3"
|
||||
cls.speculative_draft_model = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--speculative-algorithm",
|
||||
cls.speculative_algorithm,
|
||||
"--speculative-draft-model-path",
|
||||
cls.speculative_draft_model,
|
||||
"--speculative-num-steps",
|
||||
5,
|
||||
"--speculative-eagle-topk",
|
||||
16,
|
||||
"--speculative-num-draft-tokens",
|
||||
64,
|
||||
"--mem-fraction-static",
|
||||
0.7,
|
||||
"--chunked-prefill-size",
|
||||
128,
|
||||
"--max-running-requests",
|
||||
8,
|
||||
"--dtype",
|
||||
"float16",
|
||||
"--enable-return-hidden-states",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(cls.model)
|
||||
cls.return_hidden_states = [False, True]
|
||||
cls.use_list_input = [True, False]
|
||||
cls.parallel_sample_nums = [1]
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user