refactor(test): reorganize OpenAI test file structure (#7408)

This commit is contained in:
Chang Su
2025-06-21 19:37:48 -07:00
committed by GitHub
parent 1998ce4046
commit b7a2df0a44
27 changed files with 350 additions and 294 deletions

View File

@@ -0,0 +1,211 @@
import asyncio
import unittest
import openai
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestCacheReport(CustomTestCase):
    """Verify prefix-cache hit reporting end to end.

    Launches a server with --enable-cache-report and checks cached-token
    counts both through the native /generate endpoint
    (meta_info["cached_tokens"]) and through the OpenAI-compatible API
    (usage.prompt_tokens_details.cached_tokens), including concurrent
    async requests that land in the same batch.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Provisional bound; recomputed below once the template size is measured.
        cls.min_cached = 5
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            other_args=[
                "--chunked-prefill-size=40",
                "--enable-cache-report",
            ],
        )
        cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
        cls.aclient = openai.AsyncClient(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
        # run_openai is an instance method; it is deliberately called unbound
        # with `cls` standing in for `self` (it only reads class attributes).
        usage = cls.run_openai(cls, "1").usage
        # we can assume that our request is of size 1, plus the total template size
        # ideally we would like to know the begin size / end size of the template to be more precise
        total_template_size = usage.prompt_tokens - 1
        print(f"template size: {total_template_size}")
        usage2 = cls.run_openai(cls, "2").usage
        assert usage2.prompt_tokens_details.cached_tokens <= total_template_size
        # Upper bound on how many tokens an unrelated request can share with
        # the cache (i.e. the chat-template prefix/suffix overlap).
        cls.min_cached = max(
            usage2.prompt_tokens_details.cached_tokens,
            total_template_size - usage2.prompt_tokens_details.cached_tokens,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
        """Send a raw /generate request and return the HTTP response."""
        response = requests.post(
            self.base_url + "/generate",
            # we use an uncommon start to minimise the chance that the cache is hit by chance
            json={
                "text": "_ The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        return response

    def run_openai(self, message):
        """Send a single-user-message chat completion; return the full response."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                # {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": message},
            ],
            temperature=0,
            max_tokens=100,
        )
        return response

    async def run_openai_async(self, message):
        """Async variant of run_openai, using the shared AsyncClient."""
        response = await self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "user", "content": message},
            ],
            temperature=0,
            max_tokens=100,
        )
        return response

    def cache_report_openai(self, message):
        """Issue the same request twice and check cached-token accounting.

        Returns the cached-token count of the first request.
        """
        response = self.run_openai(message)
        print(
            f"openai first request cached_tokens: {int(response.usage.prompt_tokens_details.cached_tokens)}"
        )
        first_cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        # assert int(response.usage.cached_tokens) == 0
        assert first_cached_tokens <= self.min_cached
        response = self.run_openai(message)
        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        print(f"openai second request cached_tokens: {cached_tokens}")
        assert cached_tokens > 0
        # On the repeat request the whole prompt except the last token is cached.
        assert cached_tokens == int(response.usage.prompt_tokens) - 1
        return first_cached_tokens

    async def cache_report_openai_async(self, message):
        """Return (cached_tokens, prompt_tokens) for one async request."""
        response = await self.run_openai_async(message)
        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        prompt_tokens = int(response.usage.prompt_tokens)
        return cached_tokens, prompt_tokens

    def test_generate(self):
        """/generate: near-zero cache on first hit, full prefix cached on second."""
        print("=" * 100)
        response = self.run_decode()
        # print(response.json())
        cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
        print(f"sglang first request cached_tokens: {cached_tokens}")
        print(
            f"sglang first request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
        )
        # can't assure to be 0: depends on the initialisation request / if a template is used with the model
        assert cached_tokens < self.min_cached
        response = self.run_decode()
        cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
        print(f"sglang second request cached_tokens: {cached_tokens}")
        print(
            f"sglang second request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
        )
        assert cached_tokens == int(response.json()["meta_info"]["prompt_tokens"]) - 1

    def test_cache_split_prefill_openai(self):
        """A prompt longer than --chunked-prefill-size must still not be cached on first use."""
        print("=" * 100)
        self.cache_report_openai(
            "€ This is a very long and unique text that should not be already cached, the twist is"
            " that it should be longer than the chunked-prefill-size, so it should be split among"
            " several prefill requests. Still, it shouldn't be cached"
        )

    def test_cache_report_openai(self):
        print("=" * 100)
        # warm up the cache, for the template
        self.run_openai("Introduce the capital of France.")
        first_cached_tokens_1 = self.run_openai(
            "How many sparrow do you need to lift a coconut?"
        ).usage.prompt_tokens_details.cached_tokens
        usage_2 = self.run_openai("* sing something about cats").usage
        first_cached_tokens_2 = usage_2.prompt_tokens_details.cached_tokens
        # first request may not have 0 cached tokens, but if they only have the template in common they
        # should be the same once the cache is warmed up
        assert first_cached_tokens_1 == first_cached_tokens_2
        resp = self.run_openai("* sing something about cats and dogs")
        print(resp.usage)
        resp = self.run_openai("* sing something about cats, please")
        print(resp.usage)
        # Requests sharing a long prefix should reuse most of the cached prompt.
        assert (
            resp.usage.prompt_tokens_details.cached_tokens
            >= usage_2.prompt_tokens - self.min_cached
        )

    def test_cache_report_openai_async(self):
        """Identical requests scheduled into the same batch cannot reuse each other's cache."""
        print("=" * 100)

        async def run_test():
            task0 = asyncio.create_task(
                self.cache_report_openai_async(
                    "first request, to start the inference and let the next two request be started in the same batch"
                )
            )
            await asyncio.sleep(0.05)  # to force the first request to be started first
            task1 = asyncio.create_task(
                self.cache_report_openai_async(
                    "> can the same batch parallel request use the cache?"
                )
            )
            task2 = asyncio.create_task(
                self.cache_report_openai_async(
                    "> can the same batch parallel request use the cache?"
                )
            )
            result0, result1, result2 = await asyncio.gather(task0, task1, task2)
            cached_tokens0, prompt_tokens0 = result0
            cached_tokens1, prompt_tokens1 = result1
            cached_tokens2, prompt_tokens2 = result2
            print(
                f"Async request 0 - Cached tokens: {cached_tokens0}, Prompt tokens: {prompt_tokens0}"
            )
            print(
                f"Async request 1 - Cached tokens: {cached_tokens1}, Prompt tokens: {prompt_tokens1}"
            )
            print(
                f"Async request 2 - Cached tokens: {cached_tokens2}, Prompt tokens: {prompt_tokens2}"
            )
            # Assert that no requests used the cache (because first is alone, and the next two are in the same batch)
            # If a new optimisation limiting starting request with same prefix at the same time was added
            # to maximise the cache hit, this would not be true
            assert cached_tokens1 == cached_tokens2 == cached_tokens0

        asyncio.run(run_test())
# Allow running this test file directly, outside the pytest/unittest runner.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,188 @@
"""
Usage:
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_with_reasoning
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_without_reasoning
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_with_reasoning
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_without_reasoning
"""
import asyncio
import json
import os
import sys
import time
import unittest
import openai
import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestEnableThinking(CustomTestCase):
    """Tests for chat_template_kwargs["enable_thinking"] with the qwen3 reasoning parser.

    Verifies that reasoning_content is populated when thinking is enabled and
    absent/None when disabled, for both non-streaming and streaming requests.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-1234"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--reasoning-parser",
                "qwen3",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_chat_completion_with_reasoning(self):
        # Test non-streaming with "enable_thinking": True, reasoning_content should not be empty
        client = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0,
                "separate_reasoning": True,
                "chat_template_kwargs": {"enable_thinking": True},
            },
        )
        self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
        data = client.json()
        self.assertIn("choices", data)
        self.assertTrue(len(data["choices"]) > 0)
        self.assertIn("message", data["choices"][0])
        self.assertIn("reasoning_content", data["choices"][0]["message"])
        self.assertIsNotNone(data["choices"][0]["message"]["reasoning_content"])

    def test_chat_completion_without_reasoning(self):
        # Test non-streaming with "enable_thinking": False, reasoning_content should be empty
        client = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0,
                "separate_reasoning": True,
                "chat_template_kwargs": {"enable_thinking": False},
            },
        )
        self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
        data = client.json()
        self.assertIn("choices", data)
        self.assertTrue(len(data["choices"]) > 0)
        self.assertIn("message", data["choices"][0])
        # The field may be omitted entirely; if present it must be None.
        if "reasoning_content" in data["choices"][0]["message"]:
            self.assertIsNone(data["choices"][0]["message"]["reasoning_content"])

    def test_stream_chat_completion_with_reasoning(self):
        # Test streaming with "enable_thinking": True, reasoning_content should not be empty
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0,
                "separate_reasoning": True,
                "stream": True,
                "chat_template_kwargs": {"enable_thinking": True},
            },
            stream=True,
        )
        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
        has_reasoning = False
        has_content = False
        print("\n=== Stream With Reasoning ===")
        for line in response.iter_lines():
            if line:
                line = line.decode("utf-8")
                # SSE frames look like "data: {...}"; skip the terminal "[DONE]" marker.
                if line.startswith("data:") and not line.startswith("data: [DONE]"):
                    data = json.loads(line[6:])
                    if "choices" in data and len(data["choices"]) > 0:
                        delta = data["choices"][0].get("delta", {})
                        if "reasoning_content" in delta and delta["reasoning_content"]:
                            has_reasoning = True
                        if "content" in delta and delta["content"]:
                            has_content = True
        self.assertTrue(
            has_reasoning,
            "The reasoning content is not included in the stream response",
        )
        self.assertTrue(
            has_content, "The stream response does not contain normal content"
        )

    def test_stream_chat_completion_without_reasoning(self):
        # Test streaming with "enable_thinking": False, reasoning_content should be empty
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0,
                "separate_reasoning": True,
                "stream": True,
                "chat_template_kwargs": {"enable_thinking": False},
            },
            stream=True,
        )
        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
        has_reasoning = False
        has_content = False
        print("\n=== Stream Without Reasoning ===")
        for line in response.iter_lines():
            if line:
                line = line.decode("utf-8")
                if line.startswith("data:") and not line.startswith("data: [DONE]"):
                    data = json.loads(line[6:])
                    if "choices" in data and len(data["choices"]) > 0:
                        delta = data["choices"][0].get("delta", {})
                        if "reasoning_content" in delta and delta["reasoning_content"]:
                            has_reasoning = True
                        if "content" in delta and delta["content"]:
                            has_content = True
        self.assertFalse(
            has_reasoning,
            "The reasoning content should not be included in the stream response",
        )
        self.assertTrue(
            has_content, "The stream response does not contain normal content"
        )
# Allow running this test file directly, outside the pytest/unittest runner.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,153 @@
"""
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
"""
import json
import unittest
from concurrent.futures import ThreadPoolExecutor
import openai
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
def setup_class(cls, backend: str):
    """Attach model/URL/schema attributes to *cls* and launch a test server.

    Shared by the per-backend test classes; `backend` selects the value
    passed to the server's --grammar-backend flag.
    """
    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    cls.base_url = DEFAULT_URL_FOR_TEST
    # Schema used by every JSON-constrained test in this module.
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
        "additionalProperties": False,
    }
    cls.json_schema = json.dumps(schema)
    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=[
            "--max-running-requests",
            "10",
            "--grammar-backend",
            backend,
        ],
    )
class TestJSONConstrainedOutlinesBackend(CustomTestCase):
    """JSON-schema constrained generation tests (outlines backend).

    Subclasses only override setUpClass to launch the server with a
    different --grammar-backend; all test logic lives here.
    """

    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="outlines")

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
        """POST /generate with an optional json_schema and validate the output.

        A falsy schema or the sentinel "INVALID" skips client-side validation.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                    "json_schema": json_schema,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        ret = response.json()
        print(json.dumps(ret))
        print("=" * 100)
        if not json_schema or json_schema == "INVALID":
            return
        # Make sure the json output is valid
        try:
            js_obj = json.loads(ret["text"])
        except (TypeError, json.decoder.JSONDecodeError):
            # Surface the offending output before re-raising so CI logs show
            # what the model actually produced (consistent with test_json_openai).
            # .get avoids a masking KeyError if "text" is missing entirely.
            print("JSONDecodeError", ret.get("text"))
            raise
        self.assertIsInstance(js_obj["name"], str)
        self.assertIsInstance(js_obj["population"], int)

    def test_json_generate(self):
        self.run_decode(json_schema=self.json_schema)

    def test_json_invalid(self):
        # The server must survive an unparsable schema without crashing.
        self.run_decode(json_schema="INVALID")

    def test_json_openai(self):
        """Constrained generation through the OpenAI-compatible response_format API."""
        client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "Introduce the capital of France. Return in a JSON format.",
                },
            ],
            temperature=0,
            max_tokens=128,
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "foo", "schema": json.loads(self.json_schema)},
            },
        )
        text = response.choices[0].message.content
        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        self.assertIsInstance(js_obj["name"], str)
        self.assertIsInstance(js_obj["population"], int)

    def test_mix_json_and_other(self):
        """Interleave schema-constrained and unconstrained requests concurrently."""
        json_schemas = [None, None, self.json_schema, self.json_schema] * 10
        with ThreadPoolExecutor(len(json_schemas)) as executor:
            list(executor.map(self.run_decode, json_schemas))
class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
    """Same JSON-constrained test suite, run against the xgrammar backend."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, "xgrammar")
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
    """Same JSON-constrained test suite, run against the llguidance backend."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, "llguidance")
# Allow running this test file directly, outside the pytest/unittest runner.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,137 @@
"""
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeOutlines.test_json_mode_response
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeOutlines.test_json_mode_with_streaming
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeXGrammar.test_json_mode_response
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeXGrammar.test_json_mode_with_streaming
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeLLGuidance.test_json_mode_response
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeLLGuidance.test_json_mode_with_streaming
"""
import json
import unittest
import openai
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
def setup_class(cls, backend):
    """Launch a test server for *cls* using the given grammar backend.

    Also attaches an OpenAI client pointed at the server's /v1 endpoint.
    """
    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    cls.base_url = DEFAULT_URL_FOR_TEST
    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=[
            "--max-running-requests",
            "10",
            "--grammar-backend",
            backend,
        ],
    )
    cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
class TestJSONModeOutlines(unittest.TestCase):
    """JSON-mode (response_format={"type": "json_object"}) tests, outlines backend."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, "outlines")

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_json_mode_response(self):
        """Test that response_format json_object (also known as "json mode") produces valid JSON, even without a system prompt that mentions JSON."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                # We are deliberately omitting "That produces JSON" or similar phrases from the assistant prompt so that we don't have misleading test results
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant that gives a short answer.",
                },
                {"role": "user", "content": "What is the capital of Bulgaria?"},
            ],
            temperature=0,
            max_tokens=128,
            response_format={"type": "json_object"},
        )
        text = response.choices[0].message.content
        print(f"Response ({len(text)} characters): {text}")
        # Verify the response is valid JSON
        try:
            js_obj = json.loads(text)
        except json.JSONDecodeError as e:
            self.fail(f"Response is not valid JSON. Error: {e}. Response: {text}")
        # Verify it's actually an object (dict)
        self.assertIsInstance(js_obj, dict, f"Response is not a JSON object: {text}")

    def test_json_mode_with_streaming(self):
        """Test that streaming with json_object response (also known as "json mode") format works correctly, even without a system prompt that mentions JSON."""
        stream = self.client.chat.completions.create(
            model=self.model,
            messages=[
                # We are deliberately omitting "That produces JSON" or similar phrases from the assistant prompt so that we don't have misleading test results
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant that gives a short answer.",
                },
                {"role": "user", "content": "What is the capital of Bulgaria?"},
            ],
            temperature=0,
            max_tokens=128,
            response_format={"type": "json_object"},
            stream=True,
        )
        # Collect all chunks
        chunks = []
        for chunk in stream:
            if chunk.choices[0].delta.content is not None:
                chunks.append(chunk.choices[0].delta.content)
        full_response = "".join(chunks)
        print(
            f"Concatenated Response ({len(full_response)} characters): {full_response}"
        )
        # Verify the combined response is valid JSON
        try:
            js_obj = json.loads(full_response)
        except json.JSONDecodeError as e:
            self.fail(
                f"Streamed response is not valid JSON. Error: {e}. Response: {full_response}"
            )
        self.assertIsInstance(js_obj, dict)
class TestJSONModeXGrammar(TestJSONModeOutlines):
    """Same JSON-mode test suite, run against the xgrammar backend."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, "xgrammar")
class TestJSONModeLLGuidance(TestJSONModeOutlines):
    """Same JSON-mode test suite, run against the llguidance backend."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, "llguidance")
# Allow running this test file directly, outside the pytest/unittest runner.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,98 @@
import re
import openai
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
# -------------------------------------------------------------------------
# EBNF Test Class: TestOpenAIServerEBNF
# Launches the server with xgrammar, has only EBNF tests
# -------------------------------------------------------------------------
class TestOpenAIServerEBNF(CustomTestCase):
    """EBNF-constrained generation via the OpenAI-compatible API.

    The server is launched with the xgrammar backend, which supports the
    `ebnf` extra-body parameter.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        # passing xgrammar specifically
        other_args = ["--grammar-backend", "xgrammar"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        # All requests below go through the OpenAI-compatible /v1 prefix.
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_ebnf(self):
        """
        Ensure we can pass `ebnf` to the local openai server
        and that it enforces the grammar.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        ebnf_grammar = r"""
        root ::= "Hello" | "Hi" | "Hey"
        """
        # Trailing punctuation/whitespace is tolerated by the check.
        pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$")
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful EBNF test bot."},
                {"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."},
            ],
            temperature=0,
            max_tokens=32,
            extra_body={"ebnf": ebnf_grammar},
        )
        text = response.choices[0].message.content.strip()
        self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
        self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")

    def test_ebnf_strict_json(self):
        """
        A stricter EBNF that produces exactly {"name":"Alice"} format
        with no trailing punctuation or extra fields.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        ebnf_grammar = r"""
        root ::= "{" pair "}"
        pair ::= "\"name\"" ":" string
        string ::= "\"" [A-Za-z]+ "\""
        """
        pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$')
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "EBNF mini-JSON generator."},
                {
                    "role": "user",
                    "content": "Generate single key JSON with only letters.",
                },
            ],
            temperature=0,
            max_tokens=64,
            extra_body={"ebnf": ebnf_grammar},
        )
        text = response.choices[0].message.content.strip()
        self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
        self.assertRegex(
            text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
        )

View File

@@ -0,0 +1,356 @@
import json
import re
import time
import unittest
from abc import ABC
import numpy as np
import openai
import torch
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class BaseTestOpenAIServerWithHiddenStates(ABC):
    """Shared hidden-states test logic; concrete subclasses launch the server.

    Subclasses must provide: model, base_url, api_key, tokenizer, and the
    sweep lists return_hidden_states / use_list_input / parallel_sample_nums.
    """

    @classmethod
    def setUpClass(cls):
        # Default sweep; concrete subclasses typically override all three.
        cls.return_hidden_states = [False, True]
        cls.use_list_input = [True, False]
        cls.parallel_sample_nums = [1, 2]

    def test_completion(self):
        for return_hidden_states in self.return_hidden_states:
            for use_list_input in self.use_list_input:
                for parallel_sample_num in self.parallel_sample_nums:
                    self.run_completion(
                        use_list_input,
                        parallel_sample_num,
                        return_hidden_states,
                    )

    def test_completion_stream(self):
        # parallel sampling and list input are not supported in streaming mode
        for return_hidden_states in self.return_hidden_states:
            for use_list_input in self.use_list_input:
                for parallel_sample_num in self.parallel_sample_nums:
                    self.run_completion_stream(
                        use_list_input,
                        parallel_sample_num,
                        return_hidden_states,
                    )

    def test_chat_completion(self):
        for return_hidden_states in self.return_hidden_states:
            # parallel sample num 2 breaks in the adapter with a 400 for EAGLE
            for parallel_sample_num in self.parallel_sample_nums:
                self.run_chat_completion(parallel_sample_num, return_hidden_states)

    def test_chat_completion_stream(self):
        for return_hidden_states in self.return_hidden_states:
            # parallel sample num > 1 breaks in the adapter with a 400 for EAGLE
            for parallel_sample_num in self.parallel_sample_nums:
                self.run_chat_completion_stream(
                    parallel_sample_num, return_hidden_states
                )

    def run_completion(
        self,
        use_list_input,
        parallel_sample_num,
        return_hidden_states,
    ):
        """Non-streaming /v1/completions: hidden_states present iff requested."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        # A list prompt yields one choice group per element; a plain string yields one.
        prompt_arg = [prompt, prompt] if use_list_input else prompt
        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            n=parallel_sample_num,
            extra_body=dict(return_hidden_states=return_hidden_states),
        )
        for choice in response.choices:
            assert hasattr(choice, "hidden_states") == return_hidden_states
            if return_hidden_states:
                assert choice.hidden_states is not None, "hidden_states was None"

    def run_completion_stream(
        self,
        use_list_input,
        parallel_sample_num,
        return_hidden_states,
    ):
        """Streaming /v1/completions: exactly n * num_prompts hidden states arrive."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if use_list_input:
            prompt_arg = [prompt, prompt]
            num_choices = len(prompt_arg)
        else:
            prompt_arg = prompt
            num_choices = 1
        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
            extra_body=dict(return_hidden_states=return_hidden_states),
        )
        hidden_states_list = []
        for response in generator:
            for choice in response.choices:
                if hasattr(choice, "hidden_states"):
                    assert return_hidden_states
                    assert choice.hidden_states is not None
                    hidden_states_list.append(choice.hidden_states)
        if return_hidden_states:
            assert (
                len(hidden_states_list) == parallel_sample_num * num_choices
            ), f"Expected {parallel_sample_num * num_choices} hidden states, got {len(hidden_states_list)}"
        else:
            assert (
                hidden_states_list == []
            ), "hidden_states were returned and should not have been"

    def run_chat_completion(self, parallel_sample_num, return_hidden_states):
        """Non-streaming /v1/chat/completions: hidden_states present iff requested."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "What is the capital of France? Answer in a few words.",
                },
            ],
            temperature=0,
            n=parallel_sample_num,
            extra_body=dict(return_hidden_states=return_hidden_states),
        )
        for choice in response.choices:
            assert hasattr(choice, "hidden_states") == return_hidden_states
            if return_hidden_states:
                assert choice.hidden_states is not None, "hidden_states was None"

    def run_chat_completion_stream(
        self, parallel_sample_num=1, return_hidden_states=False
    ):
        """Streaming /v1/chat/completions: one hidden-states delta per sample."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
            extra_body=dict(return_hidden_states=return_hidden_states),
        )
        hidden_states_list = []
        for response in generator:
            for choice in response.choices:
                if hasattr(choice.delta, "hidden_states"):
                    assert return_hidden_states
                    assert choice.delta.hidden_states is not None
                    hidden_states_list.append(choice.delta.hidden_states)
        if return_hidden_states:
            assert (
                len(hidden_states_list) == parallel_sample_num
            ), f"Expected {parallel_sample_num} hidden states, got {len(hidden_states_list)}"
        else:
            assert (
                hidden_states_list == []
            ), "hidden_states were returned and should not have been"
class TestOpenAIServerWithHiddenStatesEnabled(
    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
    """Hidden-states sweep against a plain server launched with --enable-return-hidden-states."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        launch_args = ["--enable-return-hidden-states"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=launch_args,
        )
        cls.base_url = cls.base_url + "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        # Sweep parameters consumed by the shared base-class tests.
        cls.return_hidden_states = [False, True]
        cls.use_list_input = [True, False]
        cls.parallel_sample_nums = [1, 2]

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
class TestOpenAIServerWithHiddenStatesEnabledAndCUDAGraphDisabled(
    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
    """Same hidden-states sweep, but with CUDA graphs disabled on the server."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        launch_args = ["--enable-return-hidden-states", "--disable-cuda-graph"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=launch_args,
        )
        cls.base_url = cls.base_url + "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        # Sweep parameters consumed by the shared base-class tests.
        cls.return_hidden_states = [False, True]
        cls.use_list_input = [True, False]
        cls.parallel_sample_nums = [1]

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
class TestOpenAIServerWithEAGLEAndHiddenStatesEnabled(
    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
    """Hidden-states sweep with EAGLE speculative decoding enabled."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.speculative_draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
        cls.speculative_algorithm = "EAGLE"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft-model-path",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                5,
                "--speculative-eagle-topk",
                8,
                "--speculative-num-draft-tokens",
                64,
                "--mem-fraction-static",
                0.7,
                "--chunked-prefill-size",
                128,
                "--max-running-requests",
                8,
                "--enable-return-hidden-states",
            ],
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
        # Parallel sampling > 1 returns HTTP 400 with EAGLE, so sweep n=1 only.
        cls.return_hidden_states = [False, True]
        cls.use_list_input = [True, False]
        cls.parallel_sample_nums = [1]

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
class TestOpenAIServerWithEAGLE3AndHiddenStatesEnabled(
    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
    """Hidden-states sweep with EAGLE3 speculative decoding enabled."""

    @classmethod
    def setUpClass(cls):
        cls.model = "meta-llama/Llama-3.1-8B-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.speculative_algorithm = "EAGLE3"
        cls.speculative_draft_model = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                cls.speculative_algorithm,
                "--speculative-draft-model-path",
                cls.speculative_draft_model,
                "--speculative-num-steps",
                5,
                "--speculative-eagle-topk",
                16,
                "--speculative-num-draft-tokens",
                64,
                "--mem-fraction-static",
                0.7,
                "--chunked-prefill-size",
                128,
                "--max-running-requests",
                8,
                "--dtype",
                "float16",
                "--enable-return-hidden-states",
            ],
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(cls.model)
        # Parallel sampling > 1 returns HTTP 400 with EAGLE-family decoding.
        cls.return_hidden_states = [False, True]
        cls.use_list_input = [True, False]
        cls.parallel_sample_nums = [1]

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
# Allow running this test file directly, outside the pytest/unittest runner.
if __name__ == "__main__":
    unittest.main()

# ---- View File (concatenation artifact): new file begins below ----
# ---- openai_server/features/test_reasoning_content.py, @@ -0,0 +1,343 @@ ----
"""
Usage:
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_false
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true_stream_reasoning_false
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_false
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_true
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentStartup.test_nonstreaming
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentStartup.test_streaming
"""
import json
import unittest
import openai
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_REASONING_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestReasoningContentAPI(CustomTestCase):
    """End-to-end checks of reasoning-content separation with the server
    launched with ``--reasoning-parser deepseek-r1``.

    Each test sends one small chat completion and verifies whether the
    reasoning text is surfaced separately (``reasoning_content``) or left
    inline in ``content``, for both streaming and non-streaming responses.
    """

    @classmethod
    def setUpClass(cls):
        # One shared server for the whole class, launched with the
        # DeepSeek-R1 reasoning parser enabled.
        cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-1234"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--reasoning-parser",
                "deepseek-r1",
            ],
        )
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree launched in setUpClass.
        kill_process_tree(cls.process.pid)

    def _chat_payload(self, stream=False, **extra_body):
        """Return a minimal chat-completion payload.

        ``extra_body`` carries the SGLang-specific options
        (``separate_reasoning`` / ``stream_reasoning``).
        """
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "extra_body": extra_body,
        }
        if stream:
            payload["stream"] = True
        return payload

    def _collect_stream(self, response):
        """Accumulate (reasoning_content, content) text from a streamed response."""
        reasoning_content = ""
        content = ""
        for chunk in response:
            delta = chunk.choices[0].delta
            if delta.content:
                content += delta.content
            elif delta.reasoning_content:
                reasoning_content += delta.reasoning_content
        return reasoning_content, content

    def test_streaming_separate_reasoning_false(self):
        # separate_reasoning=False: reasoning stays inline in `content`,
        # so no chunk should carry `reasoning_content`.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            **self._chat_payload(stream=True, separate_reasoning=False)
        )
        reasoning_content, content = self._collect_stream(response)
        assert len(reasoning_content) == 0
        assert len(content) > 0

    def test_streaming_separate_reasoning_true(self):
        # separate_reasoning=True: reasoning is streamed via
        # `reasoning_content` and the answer via `content`; both non-empty.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            **self._chat_payload(stream=True, separate_reasoning=True)
        )
        reasoning_content, content = self._collect_stream(response)
        assert len(reasoning_content) > 0
        assert len(content) > 0

    def test_streaming_separate_reasoning_true_stream_reasoning_false(self):
        # stream_reasoning=False: the reasoning text is delivered whole rather
        # than token by token, so keep the last non-empty reasoning payload
        # instead of concatenating.  (The previous `first_chunk` bookkeeping
        # was dead code — its in-loop assert could never execute — and has
        # been removed.)
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            **self._chat_payload(
                stream=True, separate_reasoning=True, stream_reasoning=False
            )
        )
        reasoning_content = ""
        content = ""
        for chunk in response:
            delta = chunk.choices[0].delta
            if delta.reasoning_content:
                reasoning_content = delta.reasoning_content
            if delta.content:
                content += delta.content
        assert len(reasoning_content) > 0
        assert len(content) > 0

    def test_nonstreaming_separate_reasoning_false(self):
        # separate_reasoning=False: the message must not expose reasoning_content.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            **self._chat_payload(separate_reasoning=False)
        )
        message = response.choices[0].message
        assert not message.reasoning_content
        assert len(message.content) > 0

    def test_nonstreaming_separate_reasoning_true(self):
        # separate_reasoning=True: both fields must be populated.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            **self._chat_payload(separate_reasoning=True)
        )
        message = response.choices[0].message
        assert len(message.reasoning_content) > 0
        assert len(message.content) > 0
class TestReasoningContentWithoutParser(CustomTestCase):
    """Checks that reasoning separation is a no-op when the server is launched
    WITHOUT ``--reasoning-parser``: the ``separate_reasoning`` /
    ``stream_reasoning`` options are accepted, but ``reasoning_content`` stays
    empty and all text is returned inline in ``content``.
    """

    @classmethod
    def setUpClass(cls):
        # One shared server for the whole class, deliberately launched
        # without any reasoning parser.
        cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-1234"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[],  # No reasoning parser
        )
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree launched in setUpClass.
        kill_process_tree(cls.process.pid)

    def _chat_payload(self, stream=False, **extra_body):
        """Return a minimal chat-completion payload.

        ``extra_body`` carries the SGLang-specific options
        (``separate_reasoning`` / ``stream_reasoning``).
        """
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "extra_body": extra_body,
        }
        if stream:
            payload["stream"] = True
        return payload

    def _collect_stream(self, response):
        """Accumulate (reasoning_content, content) text from a streamed response."""
        reasoning_content = ""
        content = ""
        for chunk in response:
            delta = chunk.choices[0].delta
            if delta.content:
                content += delta.content
            elif delta.reasoning_content:
                reasoning_content += delta.reasoning_content
        return reasoning_content, content

    def test_streaming_separate_reasoning_false(self):
        # Without a parser and with separate_reasoning=False, no chunk should
        # carry reasoning_content.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            **self._chat_payload(stream=True, separate_reasoning=False)
        )
        reasoning_content, content = self._collect_stream(response)
        assert len(reasoning_content) == 0
        assert len(content) > 0

    def test_streaming_separate_reasoning_true(self):
        # Without a parser, separate_reasoning=True still yields NO
        # reasoning_content — the flag has nothing to split on.  (The original
        # copy-pasted comment claiming it "should not be empty" was wrong.)
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            **self._chat_payload(stream=True, separate_reasoning=True)
        )
        reasoning_content, content = self._collect_stream(response)
        assert len(reasoning_content) == 0
        assert len(content) > 0

    def test_streaming_separate_reasoning_true_stream_reasoning_false(self):
        # Without a parser no reasoning_content should appear at all, whether
        # or not reasoning streaming is enabled.  (The previous `first_chunk`
        # bookkeeping was dead code — its in-loop assert could never execute —
        # and has been removed.)
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            **self._chat_payload(
                stream=True, separate_reasoning=True, stream_reasoning=False
            )
        )
        reasoning_content = ""
        content = ""
        for chunk in response:
            delta = chunk.choices[0].delta
            if delta.reasoning_content:
                reasoning_content = delta.reasoning_content
            if delta.content:
                content += delta.content
        assert not reasoning_content
        assert len(content) > 0

    def test_nonstreaming_separate_reasoning_false(self):
        # Without a parser and with separate_reasoning=False, the message must
        # not expose reasoning_content.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            **self._chat_payload(separate_reasoning=False)
        )
        message = response.choices[0].message
        assert not message.reasoning_content
        assert len(message.content) > 0

    def test_nonstreaming_separate_reasoning_true(self):
        # Even with separate_reasoning=True, no parser means no
        # reasoning_content in the final message.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            **self._chat_payload(separate_reasoning=True)
        )
        message = response.choices[0].message
        assert not message.reasoning_content
        assert len(message.content) > 0
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()