refactor(test): reorganize OpenAI test file structure (#7408)
This commit is contained in:
0
test/srt/openai_server/features/__init__.py
Normal file
0
test/srt/openai_server/features/__init__.py
Normal file
211
test/srt/openai_server/features/test_cache_report.py
Normal file
211
test/srt/openai_server/features/test_cache_report.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestCacheReport(CustomTestCase):
    """Verify the `cached_tokens` accounting reported by the server.

    The server is launched with --enable-cache-report and a deliberately
    small --chunked-prefill-size so that longer prompts are split across
    several prefill passes (exercised by test_cache_split_prefill_openai).
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Initial fallback bound; refined below once the chat-template
        # size has been measured with two probe requests.
        cls.min_cached = 5
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            other_args=[
                "--chunked-prefill-size=40",
                "--enable-cache-report",
            ],
        )
        cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
        cls.aclient = openai.AsyncClient(api_key="EMPTY", base_url=f"{cls.base_url}/v1")

        # NOTE: run_openai is an instance method; `cls` is passed explicitly
        # as `self` because no instance exists during setUpClass.
        usage = cls.run_openai(cls, "1").usage
        # we can assume that our request is of size 1, plus the total template size
        # ideally we would like to know the begin size / end size of the template to be more precise
        total_template_size = usage.prompt_tokens - 1
        print(f"template size: {total_template_size}")
        usage2 = cls.run_openai(cls, "2").usage
        assert usage2.prompt_tokens_details.cached_tokens <= total_template_size
        cls.min_cached = max(
            usage2.prompt_tokens_details.cached_tokens,
            total_template_size - usage2.prompt_tokens_details.cached_tokens,
        )

    @classmethod
    def tearDownClass(cls):
        # Stop the launched server and any of its worker subprocesses.
        kill_process_tree(cls.process.pid)

    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
        """Send a raw /generate request and return the HTTP response."""
        response = requests.post(
            self.base_url + "/generate",
            # we use an uncommon start to minimise the chance that the cache is hit by chance
            json={
                "text": "_ The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        return response

    def run_openai(self, message):
        """Send a single-user-message chat completion; return the response."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                # {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": message},
            ],
            temperature=0,
            max_tokens=100,
        )
        return response

    async def run_openai_async(self, message):
        """Async variant of run_openai, used by the same-batch test below."""
        response = await self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "user", "content": message},
            ],
            temperature=0,
            max_tokens=100,
        )
        return response

    def cache_report_openai(self, message):
        """Send `message` twice; check that the second request hits the cache.

        Returns the cached-token count of the first request.
        """
        response = self.run_openai(message)
        print(
            f"openai first request cached_tokens: {int(response.usage.prompt_tokens_details.cached_tokens)}"
        )
        first_cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        # assert int(response.usage.cached_tokens) == 0
        assert first_cached_tokens <= self.min_cached
        response = self.run_openai(message)
        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        print(f"openai second request cached_tokens: {cached_tokens}")
        assert cached_tokens > 0
        # On an identical second request everything but the final prompt
        # token should be served from the cache.
        assert cached_tokens == int(response.usage.prompt_tokens) - 1
        return first_cached_tokens

    async def cache_report_openai_async(self, message):
        """Return (cached_tokens, prompt_tokens) for one async request."""
        response = await self.run_openai_async(message)
        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        prompt_tokens = int(response.usage.prompt_tokens)
        return cached_tokens, prompt_tokens

    def test_generate(self):
        """The native /generate endpoint reports cached_tokens in meta_info."""
        print("=" * 100)
        response = self.run_decode()
        # print(response.json())
        cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
        print(f"sglang first request cached_tokens: {cached_tokens}")
        print(
            f"sglang first request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
        )
        # can't assure to be 0: depends on the initialisation request / if a template is used with the model
        assert cached_tokens < self.min_cached
        response = self.run_decode()
        cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
        print(f"sglang second request cached_tokens: {cached_tokens}")
        print(
            f"sglang second request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
        )
        # Second identical request: all but the last prompt token is cached.
        assert cached_tokens == int(response.json()["meta_info"]["prompt_tokens"]) - 1

    def test_cache_split_prefill_openai(self):
        """A prompt longer than --chunked-prefill-size reports correctly."""
        print("=" * 100)
        self.cache_report_openai(
            "€ This is a very long and unique text that should not be already cached, the twist is"
            " that it should be longer than the chunked-prefill-size, so it should be split among"
            " several prefill requests. Still, it shouldn't be cached"
        )

    def test_cache_report_openai(self):
        """Requests sharing only the chat template report equal cache hits."""
        print("=" * 100)
        # warm up the cache, for the template
        self.run_openai("Introduce the capital of France.")

        first_cached_tokens_1 = self.run_openai(
            "How many sparrow do you need to lift a coconut?"
        ).usage.prompt_tokens_details.cached_tokens

        usage_2 = self.run_openai("* sing something about cats").usage
        first_cached_tokens_2 = usage_2.prompt_tokens_details.cached_tokens
        # first request may not have 0 cached tokens, but if they only have the template in common they
        # should be the same once the cache is warmed up
        assert first_cached_tokens_1 == first_cached_tokens_2

        resp = self.run_openai("* sing something about cats and dogs")
        print(resp.usage)

        resp = self.run_openai("* sing something about cats, please")
        print(resp.usage)
        # The shared "* sing something about cats" prefix should now be cached.
        assert (
            resp.usage.prompt_tokens_details.cached_tokens
            >= usage_2.prompt_tokens - self.min_cached
        )

    def test_cache_report_openai_async(self):
        """Identical requests scheduled in the same batch cannot reuse each
        other's cache entries."""
        print("=" * 100)

        async def run_test():
            task0 = asyncio.create_task(
                self.cache_report_openai_async(
                    "first request, to start the inference and let the next two request be started in the same batch"
                )
            )
            await asyncio.sleep(0.05)  # to force the first request to be started first
            task1 = asyncio.create_task(
                self.cache_report_openai_async(
                    "> can the same batch parallel request use the cache?"
                )
            )
            task2 = asyncio.create_task(
                self.cache_report_openai_async(
                    "> can the same batch parallel request use the cache?"
                )
            )
            result0, result1, result2 = await asyncio.gather(task0, task1, task2)

            cached_tokens0, prompt_tokens0 = result0
            cached_tokens1, prompt_tokens1 = result1
            cached_tokens2, prompt_tokens2 = result2

            print(
                f"Async request 0 - Cached tokens: {cached_tokens0}, Prompt tokens: {prompt_tokens0}"
            )
            print(
                f"Async request 1 - Cached tokens: {cached_tokens1}, Prompt tokens: {prompt_tokens1}"
            )
            print(
                f"Async request 2 - Cached tokens: {cached_tokens2}, Prompt tokens: {prompt_tokens2}"
            )

            # Assert that no requests used the cache (because the first is alone, and the next two are in the same batch)
            # If a new optimisation limiting starting request with same prefix at the same time was added
            # to maximise the cache hit, this would not be true
            assert cached_tokens1 == cached_tokens2 == cached_tokens0

        asyncio.run(run_test())
|
||||
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
||||
188
test/srt/openai_server/features/test_enable_thinking.py
Normal file
188
test/srt/openai_server/features/test_enable_thinking.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_with_reasoning
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_without_reasoning
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_with_reasoning
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_without_reasoning
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestEnableThinking(CustomTestCase):
    """End-to-end tests for the `chat_template_kwargs["enable_thinking"]` flag.

    The server is launched with `--reasoning-parser qwen3`; with thinking
    enabled the parser splits model output into `reasoning_content` plus
    normal `content`, with it disabled no reasoning content is produced.
    The four test cases all send the same minimal request; the payload
    construction and SSE scanning are factored into private helpers.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-1234"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--reasoning-parser",
                "qwen3",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        # Stop the launched server and any of its worker subprocesses.
        kill_process_tree(cls.process.pid)

    def _post_chat(self, enable_thinking, stream=False):
        """POST a minimal chat completion with the given thinking flag.

        Returns the raw `requests.Response`; when `stream` is True the
        response body is an SSE stream (and `stream=True` is forwarded to
        `requests.post` so it is not buffered).
        """
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 0,
            "separate_reasoning": True,
            "chat_template_kwargs": {"enable_thinking": enable_thinking},
        }
        if stream:
            payload["stream"] = True
        return requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json=payload,
            stream=stream,
        )

    @staticmethod
    def _scan_stream(response):
        """Scan an SSE stream; return (has_reasoning, has_content) flags.

        A flag is set as soon as any delta carries a non-empty
        `reasoning_content` / `content` field.
        """
        has_reasoning = False
        has_content = False
        for line in response.iter_lines():
            if not line:
                continue
            line = line.decode("utf-8")
            if not line.startswith("data:") or line.startswith("data: [DONE]"):
                continue
            # Strip the SSE "data:" field prefix robustly; the previous
            # fixed `line[6:]` slice assumed exactly one space after the
            # colon, which the SSE spec does not guarantee.
            data = json.loads(line[len("data:") :].strip())
            if "choices" in data and len(data["choices"]) > 0:
                delta = data["choices"][0].get("delta", {})
                if delta.get("reasoning_content"):
                    has_reasoning = True
                if delta.get("content"):
                    has_content = True
        return has_reasoning, has_content

    def test_chat_completion_with_reasoning(self):
        # Non-streaming with "enable_thinking": True -> reasoning_content
        # must be present and non-null.
        client = self._post_chat(enable_thinking=True)

        self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
        data = client.json()

        self.assertIn("choices", data)
        self.assertTrue(len(data["choices"]) > 0)
        self.assertIn("message", data["choices"][0])
        self.assertIn("reasoning_content", data["choices"][0]["message"])
        self.assertIsNotNone(data["choices"][0]["message"]["reasoning_content"])

    def test_chat_completion_without_reasoning(self):
        # Non-streaming with "enable_thinking": False -> reasoning_content
        # must be absent or null.
        client = self._post_chat(enable_thinking=False)

        self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
        data = client.json()

        self.assertIn("choices", data)
        self.assertTrue(len(data["choices"]) > 0)
        self.assertIn("message", data["choices"][0])

        if "reasoning_content" in data["choices"][0]["message"]:
            self.assertIsNone(data["choices"][0]["message"]["reasoning_content"])

    def test_stream_chat_completion_with_reasoning(self):
        # Streaming with "enable_thinking": True -> the stream must contain
        # both reasoning deltas and normal content deltas.
        response = self._post_chat(enable_thinking=True, stream=True)

        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")

        print("\n=== Stream With Reasoning ===")
        has_reasoning, has_content = self._scan_stream(response)

        self.assertTrue(
            has_reasoning,
            "The reasoning content is not included in the stream response",
        )
        self.assertTrue(
            has_content, "The stream response does not contain normal content"
        )

    def test_stream_chat_completion_without_reasoning(self):
        # Streaming with "enable_thinking": False -> no reasoning deltas,
        # but normal content must still arrive.
        response = self._post_chat(enable_thinking=False, stream=True)

        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")

        print("\n=== Stream Without Reasoning ===")
        has_reasoning, has_content = self._scan_stream(response)

        self.assertFalse(
            has_reasoning,
            "The reasoning content should not be included in the stream response",
        )
        self.assertTrue(
            has_content, "The stream response does not contain normal content"
        )
|
||||
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
||||
153
test/srt/openai_server/features/test_json_constrained.py
Normal file
153
test/srt/openai_server/features/test_json_constrained.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
|
||||
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
|
||||
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def setup_class(cls, backend: str):
    """Launch a small-model server for `cls` using the given grammar backend.

    Stores the model name, base URL, a reference JSON schema (as a string),
    and the server process handle on the test class.
    """
    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    cls.base_url = DEFAULT_URL_FOR_TEST

    # Reference schema shared by every test: an object with a word-only
    # "name" and an integer "population", and nothing else allowed.
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
        "additionalProperties": False,
    }
    cls.json_schema = json.dumps(schema)

    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=[
            "--max-running-requests",
            "10",
            "--grammar-backend",
            backend,
        ],
    )
|
||||
|
||||
|
||||
class TestJSONConstrainedOutlinesBackend(CustomTestCase):
    """JSON-schema constrained decoding tests for the outlines backend.

    Subclasses override setUpClass to rerun the same tests against other
    grammar backends (xgrammar, llguidance).
    """

    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="outlines")

    @classmethod
    def tearDownClass(cls):
        # Stop the server launched in setup_class.
        kill_process_tree(cls.process.pid)

    def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
        """POST /generate with `json_schema` and validate the constrained output.

        For a falsy or "INVALID" schema the response is only printed, not
        validated (used by test_json_invalid / test_mix_json_and_other).
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                    "json_schema": json_schema,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        ret = response.json()
        print(json.dumps(ret))
        print("=" * 100)

        if not json_schema or json_schema == "INVALID":
            return

        # Make sure the json output is valid
        try:
            js_obj = json.loads(ret["text"])
        except (TypeError, json.decoder.JSONDecodeError):
            # Print the offending text before re-raising so the failure is
            # debuggable (consistent with test_json_openai below); the old
            # bare `raise` added no information.
            print("JSONDecodeError", ret["text"])
            raise

        self.assertIsInstance(js_obj["name"], str)
        self.assertIsInstance(js_obj["population"], int)

    def test_json_generate(self):
        """Schema-constrained generation yields an object matching the schema."""
        self.run_decode(json_schema=self.json_schema)

    def test_json_invalid(self):
        """An invalid schema string must not crash the server."""
        self.run_decode(json_schema="INVALID")

    def test_json_openai(self):
        """The OpenAI response_format=json_schema path enforces the schema too."""
        client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "Introduce the capital of France. Return in a JSON format.",
                },
            ],
            temperature=0,
            max_tokens=128,
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "foo", "schema": json.loads(self.json_schema)},
            },
        )
        text = response.choices[0].message.content

        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise

        self.assertIsInstance(js_obj["name"], str)
        self.assertIsInstance(js_obj["population"], int)

    def test_mix_json_and_other(self):
        """Constrained and unconstrained requests can run concurrently."""
        json_schemas = [None, None, self.json_schema, self.json_schema] * 10

        with ThreadPoolExecutor(len(json_schemas)) as executor:
            list(executor.map(self.run_decode, json_schemas))
|
||||
|
||||
|
||||
class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
    """Rerun all inherited JSON-constrained tests on the xgrammar backend."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="xgrammar")
|
||||
|
||||
|
||||
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
    """Rerun all inherited JSON-constrained tests on the llguidance backend."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="llguidance")
|
||||
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
||||
137
test/srt/openai_server/features/test_json_mode.py
Normal file
137
test/srt/openai_server/features/test_json_mode.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeOutlines.test_json_mode_response
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeOutlines.test_json_mode_with_streaming
|
||||
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeXGrammar.test_json_mode_response
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeXGrammar.test_json_mode_with_streaming
|
||||
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeLLGuidance.test_json_mode_response
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeLLGuidance.test_json_mode_with_streaming
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def setup_class(cls, backend):
    """Launch a small-model server with the requested grammar backend and
    attach an OpenAI client to the test class."""
    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    cls.base_url = DEFAULT_URL_FOR_TEST

    launch_args = [
        "--max-running-requests",
        "10",
        "--grammar-backend",
        backend,
    ]
    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=launch_args,
    )
    cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
|
||||
|
||||
|
||||
class TestJSONModeOutlines(unittest.TestCase):
    """Tests for OpenAI "JSON mode" (response_format={"type": "json_object"})
    on the outlines grammar backend; subclasses swap in other backends."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, "outlines")

    @classmethod
    def tearDownClass(cls):
        # Stop the server launched in setup_class.
        kill_process_tree(cls.process.pid)

    def test_json_mode_response(self):
        """Test that response_format json_object (also known as "json mode") produces valid JSON, even without a system prompt that mentions JSON."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                # We are deliberately omitting "That produces JSON" or similar phrases from the assistant prompt so that we don't have misleading test results
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant that gives a short answer.",
                },
                {"role": "user", "content": "What is the capital of Bulgaria?"},
            ],
            temperature=0,
            max_tokens=128,
            response_format={"type": "json_object"},
        )
        text = response.choices[0].message.content

        print(f"Response ({len(text)} characters): {text}")

        # Verify the response is valid JSON
        try:
            js_obj = json.loads(text)
        except json.JSONDecodeError as e:
            self.fail(f"Response is not valid JSON. Error: {e}. Response: {text}")

        # Verify it's actually an object (dict)
        self.assertIsInstance(js_obj, dict, f"Response is not a JSON object: {text}")

    def test_json_mode_with_streaming(self):
        """Test that streaming with json_object response (also known as "json mode") format works correctly, even without a system prompt that mentions JSON."""
        stream = self.client.chat.completions.create(
            model=self.model,
            messages=[
                # We are deliberately omitting "That produces JSON" or similar phrases from the assistant prompt so that we don't have misleading test results
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant that gives a short answer.",
                },
                {"role": "user", "content": "What is the capital of Bulgaria?"},
            ],
            temperature=0,
            max_tokens=128,
            response_format={"type": "json_object"},
            stream=True,
        )

        # Collect all chunks
        chunks = []
        for chunk in stream:
            if chunk.choices[0].delta.content is not None:
                chunks.append(chunk.choices[0].delta.content)
        full_response = "".join(chunks)

        print(
            f"Concatenated Response ({len(full_response)} characters): {full_response}"
        )

        # Verify the combined response is valid JSON
        try:
            js_obj = json.loads(full_response)
        except json.JSONDecodeError as e:
            self.fail(
                f"Streamed response is not valid JSON. Error: {e}. Response: {full_response}"
            )

        self.assertIsInstance(js_obj, dict)
|
||||
|
||||
|
||||
class TestJSONModeXGrammar(TestJSONModeOutlines):
    """Rerun all inherited JSON-mode tests on the xgrammar backend."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="xgrammar")
|
||||
|
||||
|
||||
class TestJSONModeLLGuidance(TestJSONModeOutlines):
    """Rerun all inherited JSON-mode tests on the llguidance backend."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="llguidance")
|
||||
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
||||
98
test/srt/openai_server/features/test_openai_server_ebnf.py
Normal file
98
test/srt/openai_server/features/test_openai_server_ebnf.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import re
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
# EBNF Test Class: TestOpenAIServerEBNF
# Launches the server with xgrammar, has only EBNF tests
# -------------------------------------------------------------------------
class TestOpenAIServerEBNF(CustomTestCase):
    """Check that an `ebnf` grammar passed via `extra_body` constrains
    chat-completion output (xgrammar is currently the only backend that
    supports EBNF)."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        # passing xgrammar specifically
        other_args = ["--grammar-backend", "xgrammar"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        # All requests below go through the OpenAI-compatible prefix.
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        # Stop the launched server and any of its worker subprocesses.
        kill_process_tree(cls.process.pid)

    def test_ebnf(self):
        """
        Ensure we can pass `ebnf` to the local openai server
        and that it enforces the grammar.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        ebnf_grammar = r"""
        root ::= "Hello" | "Hi" | "Hey"
        """
        pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$")

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful EBNF test bot."},
                {"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."},
            ],
            temperature=0,
            max_tokens=32,
            extra_body={"ebnf": ebnf_grammar},
        )
        text = response.choices[0].message.content.strip()
        self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
        self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")

    def test_ebnf_strict_json(self):
        """
        A stricter EBNF that produces exactly {"name":"Alice"} format
        with no trailing punctuation or extra fields.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        ebnf_grammar = r"""
        root ::= "{" pair "}"
        pair ::= "\"name\"" ":" string
        string ::= "\"" [A-Za-z]+ "\""
        """
        pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$')

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "EBNF mini-JSON generator."},
                {
                    "role": "user",
                    "content": "Generate single key JSON with only letters.",
                },
            ],
            temperature=0,
            max_tokens=64,
            extra_body={"ebnf": ebnf_grammar},
        )
        text = response.choices[0].message.content.strip()
        self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
        self.assertRegex(
            text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
        )
|
||||
@@ -0,0 +1,356 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import unittest
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import torch
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class BaseTestOpenAIServerWithHiddenStates(ABC):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.return_hidden_states = [False, True]
|
||||
cls.use_list_input = [True, False]
|
||||
cls.parallel_sample_nums = [1, 2]
|
||||
|
||||
def test_completion(self):
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for use_list_input in self.use_list_input:
|
||||
for parallel_sample_num in self.parallel_sample_nums:
|
||||
self.run_completion(
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
return_hidden_states,
|
||||
)
|
||||
|
||||
def test_completion_stream(self):
|
||||
# parallel sampling and list input are not supported in streaming mode
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for use_list_input in self.use_list_input:
|
||||
for parallel_sample_num in self.parallel_sample_nums:
|
||||
self.run_completion_stream(
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
return_hidden_states,
|
||||
)
|
||||
|
||||
def test_chat_completion(self):
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for (
|
||||
parallel_sample_num
|
||||
) in (
|
||||
self.parallel_sample_nums
|
||||
): # parallel sample num 2 breaks in the adapter with a 400 for EAGLE
|
||||
self.run_chat_completion(parallel_sample_num, return_hidden_states)
|
||||
|
||||
def test_chat_completion_stream(self):
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for (
|
||||
parallel_sample_num
|
||||
) in (
|
||||
self.parallel_sample_nums
|
||||
): # parallel sample num > 1 breaks in the adapter with a 400 for EAGLE
|
||||
self.run_chat_completion_stream(
|
||||
parallel_sample_num, return_hidden_states
|
||||
)
|
||||
|
||||
def run_completion(
|
||||
self,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
return_hidden_states,
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
prompt_input = prompt
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
response = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
n=parallel_sample_num,
|
||||
extra_body=dict(return_hidden_states=return_hidden_states),
|
||||
)
|
||||
|
||||
for choice in response.choices:
|
||||
assert hasattr(choice, "hidden_states") == return_hidden_states
|
||||
if return_hidden_states:
|
||||
assert choice.hidden_states is not None, "hidden_states was None"
|
||||
|
||||
def run_completion_stream(
    self,
    use_list_input,
    parallel_sample_num,
    return_hidden_states,
):
    """Issue a streaming /completions request and verify hidden-states delivery.

    Hidden states are expected to arrive exactly once per generated choice,
    i.e. ``parallel_sample_num * num_choices`` times over the whole stream.

    Args:
        use_list_input: if True, send the prompt as a two-element list.
        parallel_sample_num: value forwarded as the OpenAI ``n`` parameter.
        return_hidden_states: whether hidden states are requested via ``extra_body``.
    """
    # Fix: removed unused locals from the original (``usage`` captured per chunk
    # and ``num_prompt_tokens`` computed via the tokenizer) — neither was checked.
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    prompt = "The capital of France is"

    if use_list_input:
        prompt_arg = [prompt, prompt]
        num_choices = len(prompt_arg)
    else:
        prompt_arg = prompt
        num_choices = 1

    generator = client.completions.create(
        model=self.model,
        prompt=prompt_arg,
        temperature=0,
        max_tokens=32,
        stream=True,
        stream_options={"include_usage": True},
        n=parallel_sample_num,
        extra_body=dict(return_hidden_states=return_hidden_states),
    )

    hidden_states_list = []
    for response in generator:
        for choice in response.choices:
            # Any chunk carrying hidden_states implies they were requested.
            if hasattr(choice, "hidden_states"):
                assert return_hidden_states
                assert choice.hidden_states is not None
                hidden_states_list.append(choice.hidden_states)

    if return_hidden_states:
        assert (
            len(hidden_states_list) == parallel_sample_num * num_choices
        ), f"Expected {parallel_sample_num * num_choices} hidden states, got {len(hidden_states_list)}"
    else:
        assert (
            hidden_states_list == []
        ), "hidden_states were returned and should not have been"
|
||||
|
||||
def run_chat_completion(self, parallel_sample_num, return_hidden_states):
    """Send one non-streaming chat completion and check hidden-states exposure."""
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    chat_messages = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {
            "role": "user",
            "content": "What is the capital of France? Answer in a few words.",
        },
    ]
    response = client.chat.completions.create(
        model=self.model,
        messages=chat_messages,
        temperature=0,
        n=parallel_sample_num,
        extra_body=dict(return_hidden_states=return_hidden_states),
    )

    for choice in response.choices:
        # hidden_states must be attached exactly when it was requested.
        has_hidden = hasattr(choice, "hidden_states")
        assert has_hidden == return_hidden_states
        if return_hidden_states:
            assert choice.hidden_states is not None, "hidden_states was None"
|
||||
|
||||
def run_chat_completion_stream(
    self, parallel_sample_num=1, return_hidden_states=False
):
    """Send a streaming chat completion and check per-delta hidden-states delivery.

    Hidden states are expected to arrive once per sampled choice, so the stream
    must yield exactly ``parallel_sample_num`` of them when requested and none
    otherwise.

    Args:
        parallel_sample_num: value forwarded as the OpenAI ``n`` parameter.
        return_hidden_states: whether hidden states are requested via ``extra_body``.
    """
    # Fix: removed the unused local ``is_firsts = {}`` from the original.
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    generator = client.chat.completions.create(
        model=self.model,
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": "What is the capital of France?"},
        ],
        temperature=0,
        stream=True,
        stream_options={"include_usage": True},
        n=parallel_sample_num,
        extra_body=dict(return_hidden_states=return_hidden_states),
    )

    hidden_states_list = []

    for response in generator:
        for choice in response.choices:
            # A delta carrying hidden_states implies they were requested.
            if hasattr(choice.delta, "hidden_states"):
                assert return_hidden_states
                assert choice.delta.hidden_states is not None
                hidden_states_list.append(choice.delta.hidden_states)

    if return_hidden_states:
        assert (
            len(hidden_states_list) == parallel_sample_num
        ), f"Expected {parallel_sample_num} hidden states, got {len(hidden_states_list)}"
    else:
        assert (
            hidden_states_list == []
        ), "hidden_states were returned and should not have been"
|
||||
|
||||
|
||||
class TestOpenAIServerWithHiddenStatesEnabled(
    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
    """Hidden-states test suite against a server launched with --enable-return-hidden-states."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        launch_flags = ["--enable-return-hidden-states"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=launch_flags,
        )
        cls.base_url = cls.base_url + "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        # Request matrix consumed by the shared base-class tests.
        cls.return_hidden_states = [False, True]
        cls.use_list_input = [True, False]
        cls.parallel_sample_nums = [1, 2]

    @classmethod
    def tearDownClass(cls):
        # Terminate the launched server process and all of its children.
        kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
class TestOpenAIServerWithHiddenStatesEnabledAndCUDAGraphDisabled(
    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
    """Same hidden-states suite, but with CUDA graphs disabled on the server."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        launch_flags = ["--enable-return-hidden-states", "--disable-cuda-graph"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=launch_flags,
        )
        cls.base_url = cls.base_url + "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        # Request matrix consumed by the shared base-class tests.
        cls.return_hidden_states = [False, True]
        cls.use_list_input = [True, False]
        cls.parallel_sample_nums = [1]

    @classmethod
    def tearDownClass(cls):
        # Terminate the launched server process and all of its children.
        kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
class TestOpenAIServerWithEAGLEAndHiddenStatesEnabled(
    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
    """Hidden-states test suite with EAGLE speculative decoding enabled."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.speculative_draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
        cls.speculative_algorithm = "EAGLE"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                # Consistency fix: reuse the class attributes instead of repeating
                # the literal "EAGLE" / the raw constant, matching the EAGLE3 class.
                cls.speculative_algorithm,
                "--speculative-draft-model-path",
                cls.speculative_draft_model,
                "--speculative-num-steps",
                5,
                "--speculative-eagle-topk",
                8,
                "--speculative-num-draft-tokens",
                64,
                "--mem-fraction-static",
                0.7,
                "--chunked-prefill-size",
                128,
                "--max-running-requests",
                8,
                "--enable-return-hidden-states",
            ],
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
        # Request matrix consumed by the shared base-class tests.
        cls.return_hidden_states = [False, True]
        cls.use_list_input = [True, False]
        # NOTE(review): parallel sampling > 1 currently 400s in the adapter with
        # EAGLE (see comments in the base tests), so only n=1 is exercised here.
        cls.parallel_sample_nums = [1]

    @classmethod
    def tearDownClass(cls):
        # Terminate the launched server process and all of its children.
        kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
class TestOpenAIServerWithEAGLE3AndHiddenStatesEnabled(
    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
):
    """Hidden-states test suite with EAGLE3 speculative decoding enabled."""

    @classmethod
    def setUpClass(cls):
        # Target model plus the EAGLE3 draft model used for speculative decoding.
        cls.model = "meta-llama/Llama-3.1-8B-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.speculative_algorithm = "EAGLE3"
        cls.speculative_draft_model = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                cls.speculative_algorithm,
                "--speculative-draft-model-path",
                cls.speculative_draft_model,
                "--speculative-num-steps",
                5,
                "--speculative-eagle-topk",
                16,
                "--speculative-num-draft-tokens",
                64,
                "--mem-fraction-static",
                0.7,
                "--chunked-prefill-size",
                128,
                "--max-running-requests",
                8,
                "--dtype",
                "float16",
                "--enable-return-hidden-states",
            ],
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(cls.model)
        # Request matrix consumed by the shared base-class tests.
        cls.return_hidden_states = [False, True]
        cls.use_list_input = [True, False]
        # NOTE(review): only n=1 is exercised — parallel sampling > 1 appears
        # unsupported with speculative decoding (see comments in the base tests).
        cls.parallel_sample_nums = [1]

    @classmethod
    def tearDownClass(cls):
        # Terminate the launched server process and all of its children.
        kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
|
||||
343
test/srt/openai_server/features/test_reasoning_content.py
Normal file
343
test/srt/openai_server/features/test_reasoning_content.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_false
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true_stream_reasoning_false
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_false
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_true
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentStartup.test_nonstreaming
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentStartup.test_streaming
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_REASONING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestReasoningContentAPI(CustomTestCase):
    """Tests for reasoning-content separation when the server is launched with
    the DeepSeek-R1 reasoning parser enabled.

    With the parser active, ``separate_reasoning=True`` should route the
    model's reasoning tokens into ``reasoning_content`` instead of ``content``.
    """

    @classmethod
    def setUpClass(cls):
        # Launch a server with the deepseek-r1 reasoning parser enabled.
        cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-1234"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--reasoning-parser",
                "deepseek-r1",
            ],
        )
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        # Terminate the launched server process and all of its children.
        kill_process_tree(cls.process.pid)

    def test_streaming_separate_reasoning_false(self):
        # Streaming with separate_reasoning=False: reasoning_content should be empty.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "stream": True,
            "extra_body": {"separate_reasoning": False},
        }
        response = client.chat.completions.create(**payload)

        reasoning_content = ""
        content = ""
        # Accumulate the deltas; content and reasoning never share a chunk here.
        for chunk in response:
            if chunk.choices[0].delta.content:
                content += chunk.choices[0].delta.content
            elif chunk.choices[0].delta.reasoning_content:
                reasoning_content += chunk.choices[0].delta.reasoning_content

        assert len(reasoning_content) == 0
        assert len(content) > 0

    def test_streaming_separate_reasoning_true(self):
        # Streaming with separate_reasoning=True: reasoning_content should not be empty.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "stream": True,
            "extra_body": {"separate_reasoning": True},
        }
        response = client.chat.completions.create(**payload)

        reasoning_content = ""
        content = ""
        for chunk in response:
            if chunk.choices[0].delta.content:
                content += chunk.choices[0].delta.content
            elif chunk.choices[0].delta.reasoning_content:
                reasoning_content += chunk.choices[0].delta.reasoning_content

        assert len(reasoning_content) > 0
        assert len(content) > 0

    def test_streaming_separate_reasoning_true_stream_reasoning_false(self):
        # With stream_reasoning=False the reasoning is not streamed token by
        # token; it should arrive in one piece before/with the first content
        # chunk, and still be non-empty overall.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "stream": True,
            "extra_body": {"separate_reasoning": True, "stream_reasoning": False},
        }
        response = client.chat.completions.create(**payload)

        reasoning_content = ""
        content = ""
        # first_chunk tracks whether the one-shot reasoning payload was seen yet.
        first_chunk = False
        for chunk in response:
            if chunk.choices[0].delta.reasoning_content:
                reasoning_content = chunk.choices[0].delta.reasoning_content
                first_chunk = True
            if chunk.choices[0].delta.content:
                content += chunk.choices[0].delta.content
                if not first_chunk:
                    reasoning_content = chunk.choices[0].delta.reasoning_content
                    first_chunk = True
            if not first_chunk:
                # Before the reasoning payload arrives, no delta may carry reasoning.
                assert (
                    not chunk.choices[0].delta.reasoning_content
                    or len(chunk.choices[0].delta.reasoning_content) == 0
                )
        assert len(reasoning_content) > 0
        assert len(content) > 0

    def test_nonstreaming_separate_reasoning_false(self):
        # Non-streaming with separate_reasoning=False: reasoning_content should be empty.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "extra_body": {"separate_reasoning": False},
        }
        response = client.chat.completions.create(**payload)

        assert (
            not response.choices[0].message.reasoning_content
            or len(response.choices[0].message.reasoning_content) == 0
        )
        assert len(response.choices[0].message.content) > 0

    def test_nonstreaming_separate_reasoning_true(self):
        # Non-streaming with separate_reasoning=True: reasoning_content should not be empty.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "extra_body": {"separate_reasoning": True},
        }
        response = client.chat.completions.create(**payload)

        assert len(response.choices[0].message.reasoning_content) > 0
        assert len(response.choices[0].message.content) > 0
|
||||
|
||||
|
||||
class TestReasoningContentWithoutParser(CustomTestCase):
    """Tests that reasoning_content stays empty when NO reasoning parser is
    configured on the server, regardless of the ``separate_reasoning`` flag.
    """

    @classmethod
    def setUpClass(cls):
        # Launch a server WITHOUT a reasoning parser.
        cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-1234"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[],  # No reasoning parser
        )
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        # Terminate the launched server process and all of its children.
        kill_process_tree(cls.process.pid)

    def test_streaming_separate_reasoning_false(self):
        # Streaming with separate_reasoning=False: reasoning_content should be empty.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "stream": True,
            "extra_body": {"separate_reasoning": False},
        }
        response = client.chat.completions.create(**payload)

        reasoning_content = ""
        content = ""
        for chunk in response:
            if chunk.choices[0].delta.content:
                content += chunk.choices[0].delta.content
            elif chunk.choices[0].delta.reasoning_content:
                reasoning_content += chunk.choices[0].delta.reasoning_content

        assert len(reasoning_content) == 0
        assert len(content) > 0

    def test_streaming_separate_reasoning_true(self):
        # Without a reasoning parser, reasoning_content stays empty even when
        # separate_reasoning=True is requested (the original comment claimed
        # the opposite; the asserts below are the source of truth).
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "stream": True,
            "extra_body": {"separate_reasoning": True},
        }
        response = client.chat.completions.create(**payload)

        reasoning_content = ""
        content = ""
        for chunk in response:
            if chunk.choices[0].delta.content:
                content += chunk.choices[0].delta.content
            elif chunk.choices[0].delta.reasoning_content:
                reasoning_content += chunk.choices[0].delta.reasoning_content

        assert len(reasoning_content) == 0
        assert len(content) > 0

    def test_streaming_separate_reasoning_true_stream_reasoning_false(self):
        # Without a parser, no reasoning should appear even with
        # separate_reasoning=True and stream_reasoning=False.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "stream": True,
            "extra_body": {"separate_reasoning": True, "stream_reasoning": False},
        }
        response = client.chat.completions.create(**payload)

        reasoning_content = ""
        content = ""
        # first_chunk tracks whether any reasoning payload was observed.
        first_chunk = False
        for chunk in response:
            if chunk.choices[0].delta.reasoning_content:
                reasoning_content = chunk.choices[0].delta.reasoning_content
                first_chunk = True
            if chunk.choices[0].delta.content:
                content += chunk.choices[0].delta.content
                if not first_chunk:
                    reasoning_content = chunk.choices[0].delta.reasoning_content
                    first_chunk = True
            if not first_chunk:
                assert (
                    not chunk.choices[0].delta.reasoning_content
                    or len(chunk.choices[0].delta.reasoning_content) == 0
                )
        # No parser configured, so nothing may have been separated out.
        assert not reasoning_content or len(reasoning_content) == 0
        assert len(content) > 0

    def test_nonstreaming_separate_reasoning_false(self):
        # Non-streaming with separate_reasoning=False: reasoning_content should be empty.
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "extra_body": {"separate_reasoning": False},
        }
        response = client.chat.completions.create(**payload)

        assert (
            not response.choices[0].message.reasoning_content
            or len(response.choices[0].message.reasoning_content) == 0
        )
        assert len(response.choices[0].message.content) > 0

    def test_nonstreaming_separate_reasoning_true(self):
        # Without a parser, reasoning_content stays empty even with
        # separate_reasoning=True (the original comment claimed the opposite;
        # the asserts below are the source of truth).
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": "What is 1+3?",
                }
            ],
            "max_tokens": 100,
            "extra_body": {"separate_reasoning": True},
        }
        response = client.chat.completions.create(**payload)

        assert (
            not response.choices[0].message.reasoning_content
            or len(response.choices[0].message.reasoning_content) == 0
        )
        assert len(response.choices[0].message.content) > 0
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user