refactor(test): reorganize OpenAI test file structure (#7408)
This commit is contained in:
0
test/srt/openai_server/__init__.py
Normal file
0
test/srt/openai_server/__init__.py
Normal file
0
test/srt/openai_server/basic/__init__.py
Normal file
0
test/srt/openai_server/basic/__init__.py
Normal file
97
test/srt/openai_server/basic/test_openai_embedding.py
Normal file
97
test/srt/openai_server/basic/test_openai_embedding.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIEmbedding(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
# Configure embedding-specific args
|
||||
other_args = ["--is-embedding", "--enable-metrics"]
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=other_args,
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_embedding_single(self):
|
||||
"""Test single embedding request"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(model=self.model, input="Hello world")
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_embedding_batch(self):
|
||||
"""Test batch embedding request"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model, input=["Hello world", "Test text"]
|
||||
)
|
||||
self.assertEqual(len(response.data), 2)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
self.assertTrue(len(response.data[1].embedding) > 0)
|
||||
|
||||
def test_embedding_single_batch_str(self):
|
||||
"""Test embedding with a List[str] and length equals to 1"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(model=self.model, input=["Hello world"])
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_embedding_single_int_list(self):
|
||||
"""Test embedding with a List[int] or List[List[int]]]"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model,
|
||||
input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
|
||||
)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model,
|
||||
input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
|
||||
)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_empty_string_embedding(self):
|
||||
"""Test embedding an empty string."""
|
||||
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
# Text embedding example with empty string
|
||||
text = ""
|
||||
# Expect a BadRequestError for empty input
|
||||
with self.assertRaises(openai.BadRequestError) as cm:
|
||||
client.embeddings.create(
|
||||
model=self.model,
|
||||
input=text,
|
||||
)
|
||||
# check the status code
|
||||
self.assertEqual(cm.exception.status_code, 400)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
743
test/srt/openai_server/basic/test_openai_server.py
Normal file
743
test/srt/openai_server/basic/test_openai_server.py
Normal file
@@ -0,0 +1,743 @@
|
||||
"""
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion_stream
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion_stream
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIServer(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_completion(
|
||||
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
if token_input:
|
||||
prompt_input = self.tokenizer.encode(prompt)
|
||||
num_prompt_tokens = len(prompt_input)
|
||||
else:
|
||||
prompt_input = prompt
|
||||
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
num_prompt_tokens *= 2
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
response = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
echo=echo,
|
||||
logprobs=logprobs,
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
assert len(response.choices) == num_choices * parallel_sample_num
|
||||
|
||||
if echo:
|
||||
text = response.choices[0].text
|
||||
assert text.startswith(prompt)
|
||||
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs
|
||||
assert isinstance(response.choices[0].logprobs.tokens[0], str)
|
||||
assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
|
||||
ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
|
||||
|
||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
|
||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
assert ret_num_top_logprobs > 0
|
||||
|
||||
# when echo=True and request.logprobs>0, logprob_start_len is 0, so the first token's logprob would be None.
|
||||
if not echo:
|
||||
assert response.choices[0].logprobs.token_logprobs[0]
|
||||
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert (
|
||||
response.usage.prompt_tokens == num_prompt_tokens
|
||||
), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
|
||||
assert response.usage.completion_tokens > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def run_completion_stream(
|
||||
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
if token_input:
|
||||
prompt_input = self.tokenizer.encode(prompt)
|
||||
num_prompt_tokens = len(prompt_input)
|
||||
else:
|
||||
prompt_input = prompt
|
||||
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
num_prompt_tokens *= 2
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
generator = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
echo=echo,
|
||||
logprobs=logprobs,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
is_firsts = {}
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
|
||||
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
|
||||
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
is_first = is_firsts.get(index, True)
|
||||
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs, f"no logprobs in response"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.tokens[0], str
|
||||
), f"{response.choices[0].logprobs.tokens[0]} is not a string"
|
||||
if not (is_first and echo):
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.top_logprobs[0], dict
|
||||
), f"top_logprobs was not a dictionary"
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.top_logprobs[0]
|
||||
)
|
||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
|
||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"
|
||||
|
||||
if is_first:
|
||||
if echo:
|
||||
assert response.choices[0].text.startswith(
|
||||
prompt
|
||||
), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
|
||||
is_firsts[index] = False
|
||||
assert response.id, f"no id in response"
|
||||
assert response.created, f"no created in response"
|
||||
|
||||
for index in [i for i in range(parallel_sample_num * num_choices)]:
|
||||
assert not is_firsts.get(
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
def run_chat_completion(self, logprobs, parallel_sample_num):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France? Answer in a few words.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
logprobs=logprobs is not None and logprobs > 0,
|
||||
top_logprobs=logprobs,
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
if logprobs:
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
)
|
||||
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.content[0].top_logprobs
|
||||
)
|
||||
assert (
|
||||
ret_num_top_logprobs == logprobs
|
||||
), f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
|
||||
assert len(response.choices) == parallel_sample_num
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert isinstance(response.choices[0].message.content, str)
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert response.usage.prompt_tokens > 0
|
||||
assert response.usage.completion_tokens > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
generator = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
],
|
||||
temperature=0,
|
||||
logprobs=logprobs is not None and logprobs > 0,
|
||||
top_logprobs=logprobs,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
is_firsts = {}
|
||||
is_finished = {}
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
|
||||
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
|
||||
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
if finish_reason is not None:
|
||||
is_finished[index] = True
|
||||
|
||||
data = response.choices[0].delta
|
||||
|
||||
if is_firsts.get(index, True):
|
||||
assert (
|
||||
data.role == "assistant"
|
||||
), f"data.role was not 'assistant' for first chunk"
|
||||
is_firsts[index] = False
|
||||
continue
|
||||
|
||||
if logprobs and not is_finished.get(index, False):
|
||||
assert response.choices[0].logprobs, f"logprobs was not returned"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
), f"top_logprobs token was not a string"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs, list
|
||||
), f"top_logprobs was not a list"
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.content[0].top_logprobs
|
||||
)
|
||||
assert (
|
||||
ret_num_top_logprobs == logprobs
|
||||
), f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
|
||||
assert (
|
||||
isinstance(data.content, str)
|
||||
or isinstance(data.reasoning_content, str)
|
||||
or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
|
||||
or response.choices[0].finish_reason
|
||||
)
|
||||
assert response.id
|
||||
assert response.created
|
||||
|
||||
for index in [i for i in range(parallel_sample_num)]:
|
||||
assert not is_firsts.get(
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
def test_completion(self):
|
||||
for echo in [False, True]:
|
||||
for logprobs in [None, 5]:
|
||||
for use_list_input in [True, False]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
for token_input in [False, True]:
|
||||
self.run_completion(
|
||||
echo,
|
||||
logprobs,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
token_input,
|
||||
)
|
||||
|
||||
def test_completion_stream(self):
|
||||
# parallel sampling and list input are not supported in streaming mode
|
||||
for echo in [False, True]:
|
||||
for logprobs in [None, 5]:
|
||||
for use_list_input in [True, False]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
for token_input in [False, True]:
|
||||
self.run_completion_stream(
|
||||
echo,
|
||||
logprobs,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
token_input,
|
||||
)
|
||||
|
||||
def test_chat_completion(self):
|
||||
for logprobs in [None, 5]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
self.run_chat_completion(logprobs, parallel_sample_num)
|
||||
|
||||
def test_chat_completion_stream(self):
|
||||
for logprobs in [None, 5]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
self.run_chat_completion_stream(logprobs, parallel_sample_num)
|
||||
|
||||
def test_regex(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "name": "[\w]+",\n"""
|
||||
+ r""" "population": [\d]+\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Introduce the capital of France."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
extra_body={"regex": regex},
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
|
||||
try:
|
||||
js_obj = json.loads(text)
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
print("JSONDecodeError", text)
|
||||
raise
|
||||
assert isinstance(js_obj["name"], str)
|
||||
assert isinstance(js_obj["population"], int)
|
||||
|
||||
def test_penalty(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Introduce the capital of France."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
frequency_penalty=1.0,
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
assert isinstance(text, str)
|
||||
|
||||
def test_response_prefill(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": """
|
||||
Extract the name, size, price, and color from this product description as a JSON object:
|
||||
|
||||
<description>
|
||||
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
|
||||
</description>
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "{\n",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
extra_body={"continue_final_message": True},
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0]
|
||||
.message.content.strip()
|
||||
.startswith('"name": "SmartHome Mini",')
|
||||
)
|
||||
|
||||
def test_model_list(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
models = list(client.models.list())
|
||||
assert len(models) == 1
|
||||
assert isinstance(getattr(models[0], "max_model_len", None), int)
|
||||
|
||||
def test_retrieve_model(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
# Test retrieving an existing model
|
||||
retrieved_model = client.models.retrieve(self.model)
|
||||
self.assertEqual(retrieved_model.id, self.model)
|
||||
self.assertEqual(retrieved_model.root, self.model)
|
||||
|
||||
# Test retrieving a non-existent model
|
||||
with self.assertRaises(openai.NotFoundError):
|
||||
client.models.retrieve("non-existent-model")
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# EBNF Test Class: TestOpenAIServerEBNF
|
||||
# Launches the server with xgrammar, has only EBNF tests
|
||||
# -------------------------------------------------------------------------
|
||||
class TestOpenAIServerEBNF(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
# passing xgrammar specifically
|
||||
other_args = ["--grammar-backend", "xgrammar"]
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=other_args,
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_ebnf(self):
|
||||
"""
|
||||
Ensure we can pass `ebnf` to the local openai server
|
||||
and that it enforces the grammar.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
ebnf_grammar = r"""
|
||||
root ::= "Hello" | "Hi" | "Hey"
|
||||
"""
|
||||
pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$")
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful EBNF test bot."},
|
||||
{"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
extra_body={"ebnf": ebnf_grammar},
|
||||
)
|
||||
text = response.choices[0].message.content.strip()
|
||||
self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
|
||||
self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")
|
||||
|
||||
def test_ebnf_strict_json(self):
|
||||
"""
|
||||
A stricter EBNF that produces exactly {"name":"Alice"} format
|
||||
with no trailing punctuation or extra fields.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
ebnf_grammar = r"""
|
||||
root ::= "{" pair "}"
|
||||
pair ::= "\"name\"" ":" string
|
||||
string ::= "\"" [A-Za-z]+ "\""
|
||||
"""
|
||||
pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$')
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "EBNF mini-JSON generator."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Generate single key JSON with only letters.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=64,
|
||||
extra_body={"ebnf": ebnf_grammar},
|
||||
)
|
||||
text = response.choices[0].message.content.strip()
|
||||
self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
|
||||
self.assertRegex(
|
||||
text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIV1Rerank(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.score_tolerance = 1e-2
|
||||
|
||||
# Configure embedding-specific args
|
||||
other_args = [
|
||||
"--is-embedding",
|
||||
"--enable-metrics",
|
||||
"--disable-radix-cache",
|
||||
"--chunked-prefill-size",
|
||||
"-1",
|
||||
"--attention-backend",
|
||||
"torch_native",
|
||||
]
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=other_args,
|
||||
)
|
||||
cls.base_url += "/v1/rerank"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_rerank(self, query, docs):
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"query": query, "documents": docs},
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
def test_rerank_single(self):
|
||||
"""Test single rerank request"""
|
||||
query = TEST_RERANK_QUERY_DOCS[0]["query"]
|
||||
docs = TEST_RERANK_QUERY_DOCS[0]["documents"]
|
||||
|
||||
response = self.run_rerank(query, docs)
|
||||
|
||||
self.assertEqual(len(response), 1)
|
||||
self.assertTrue(isinstance(response[0]["score"], float))
|
||||
self.assertTrue(isinstance(response[0]["document"], str))
|
||||
self.assertTrue(isinstance(response[0]["index"], int))
|
||||
|
||||
def test_rerank_batch(self):
|
||||
"""Test batch rerank request"""
|
||||
query = TEST_RERANK_QUERY_DOCS[1]["query"]
|
||||
docs = TEST_RERANK_QUERY_DOCS[1]["documents"]
|
||||
|
||||
response = self.run_rerank(query, docs)
|
||||
|
||||
self.assertEqual(len(response), 2)
|
||||
self.assertTrue(isinstance(response[0]["score"], float))
|
||||
self.assertTrue(isinstance(response[1]["score"], float))
|
||||
self.assertTrue(isinstance(response[0]["document"], str))
|
||||
self.assertTrue(isinstance(response[1]["document"], str))
|
||||
self.assertTrue(isinstance(response[0]["index"], int))
|
||||
self.assertTrue(isinstance(response[1]["index"], int))
|
||||
|
||||
|
||||
class TestOpenAIV1Score(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
)
|
||||
cls.base_url += "/v1/score"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_score(
|
||||
self, query, items, label_token_ids, apply_softmax=False, item_first=False
|
||||
):
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"items": items,
|
||||
"label_token_ids": label_token_ids,
|
||||
"apply_softmax": apply_softmax,
|
||||
"item_first": item_first,
|
||||
},
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def test_score_text_input(self):
|
||||
"""Test scoring with text input"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Get valid token IDs from the tokenizer
|
||||
label_token_ids = []
|
||||
for item in items:
|
||||
token_ids = self.tokenizer.encode(item, add_special_tokens=False)
|
||||
if not token_ids:
|
||||
self.fail(f"Failed to encode item: {item}")
|
||||
label_token_ids.append(token_ids[0])
|
||||
|
||||
response = self.run_score(query, items, label_token_ids, apply_softmax=True)
|
||||
|
||||
# Handle error responses
|
||||
if response.get("type") == "BadRequestError":
|
||||
self.fail(f"Score request failed with error: {response['message']}")
|
||||
|
||||
# Verify response structure
|
||||
self.assertIn("scores", response, "Response should have a 'scores' field")
|
||||
self.assertIsInstance(response["scores"], list, "scores should be a list")
|
||||
self.assertEqual(
|
||||
len(response["scores"]),
|
||||
len(items),
|
||||
"Number of scores should match number of items",
|
||||
)
|
||||
|
||||
# Each score should be a list of floats in the order of label_token_ids
|
||||
for i, score_list in enumerate(response["scores"]):
|
||||
self.assertIsInstance(score_list, list, f"Score {i} should be a list")
|
||||
self.assertEqual(
|
||||
len(score_list),
|
||||
len(label_token_ids),
|
||||
f"Score {i} length should match label_token_ids",
|
||||
)
|
||||
self.assertTrue(
|
||||
all(isinstance(v, float) for v in score_list),
|
||||
f"Score {i} values should be floats",
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
sum(score_list),
|
||||
1.0,
|
||||
places=6,
|
||||
msg=f"Score {i} probabilities should sum to 1",
|
||||
)
|
||||
|
||||
def test_score_token_input(self):
|
||||
"""Test scoring with token IDs input"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Get valid token IDs
|
||||
query_ids = self.tokenizer.encode(query, add_special_tokens=False)
|
||||
item_ids = [
|
||||
self.tokenizer.encode(item, add_special_tokens=False) for item in items
|
||||
]
|
||||
label_token_ids = [
|
||||
ids[0] for ids in item_ids if ids
|
||||
] # Get first token ID of each item
|
||||
|
||||
response = self.run_score(
|
||||
query_ids, item_ids, label_token_ids, apply_softmax=True
|
||||
)
|
||||
|
||||
# Handle error responses
|
||||
if response.get("type") == "BadRequestError":
|
||||
self.fail(f"Score request failed with error: {response['message']}")
|
||||
|
||||
# Verify response structure
|
||||
self.assertIn("scores", response, "Response should have a 'scores' field")
|
||||
self.assertIsInstance(response["scores"], list, "scores should be a list")
|
||||
self.assertEqual(
|
||||
len(response["scores"]),
|
||||
len(items),
|
||||
"Number of scores should match number of items",
|
||||
)
|
||||
|
||||
# Each score should be a list of floats in the order of label_token_ids
|
||||
for i, score_list in enumerate(response["scores"]):
|
||||
self.assertIsInstance(score_list, list, f"Score {i} should be a list")
|
||||
self.assertEqual(
|
||||
len(score_list),
|
||||
len(label_token_ids),
|
||||
f"Score {i} length should match label_token_ids",
|
||||
)
|
||||
self.assertTrue(
|
||||
all(isinstance(v, float) for v in score_list),
|
||||
f"Score {i} values should be floats",
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
sum(score_list),
|
||||
1.0,
|
||||
places=6,
|
||||
msg=f"Score {i} probabilities should sum to 1",
|
||||
)
|
||||
|
||||
def test_score_error_handling(self):
|
||||
"""Test error handling for invalid inputs"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Test with invalid token ID
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"items": items,
|
||||
"label_token_ids": [999999], # Invalid token ID
|
||||
"apply_softmax": True,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
error_response = response.json()
|
||||
self.assertEqual(error_response["type"], "BadRequestError")
|
||||
self.assertIn("Token ID 999999 is out of vocabulary", error_response["message"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
279
test/srt/openai_server/basic/test_protocol.py
Normal file
279
test/srt/openai_server/basic/test_protocol.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for OpenAI API protocol models"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import unittest
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
BatchRequest,
|
||||
BatchResponse,
|
||||
ChatCompletionMessageContentImagePart,
|
||||
ChatCompletionMessageContentTextPart,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionTokenLogprob,
|
||||
ChatMessage,
|
||||
ChoiceLogprobs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
DeltaMessage,
|
||||
EmbeddingObject,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ErrorResponse,
|
||||
FileDeleteResponse,
|
||||
FileRequest,
|
||||
FileResponse,
|
||||
Function,
|
||||
FunctionResponse,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbs,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
MultimodalEmbeddingInput,
|
||||
ResponseFormat,
|
||||
ScoringRequest,
|
||||
ScoringResponse,
|
||||
StreamOptions,
|
||||
StructuralTagResponseFormat,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolChoice,
|
||||
TopLogprob,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
|
||||
class TestModelCard(unittest.TestCase):
|
||||
"""Test ModelCard protocol model"""
|
||||
|
||||
def test_model_card_serialization(self):
|
||||
"""Test model card JSON serialization"""
|
||||
card = ModelCard(id="test-model", max_model_len=4096)
|
||||
data = card.model_dump()
|
||||
self.assertEqual(data["id"], "test-model")
|
||||
self.assertEqual(data["object"], "model")
|
||||
self.assertEqual(data["max_model_len"], 4096)
|
||||
|
||||
|
||||
class TestModelList(unittest.TestCase):
|
||||
"""Test ModelList protocol model"""
|
||||
|
||||
def test_empty_model_list(self):
|
||||
"""Test empty model list creation"""
|
||||
model_list = ModelList()
|
||||
self.assertEqual(model_list.object, "list")
|
||||
self.assertEqual(len(model_list.data), 0)
|
||||
|
||||
def test_model_list_with_cards(self):
|
||||
"""Test model list with model cards"""
|
||||
cards = [
|
||||
ModelCard(id="model-1"),
|
||||
ModelCard(id="model-2", max_model_len=2048),
|
||||
]
|
||||
model_list = ModelList(data=cards)
|
||||
self.assertEqual(len(model_list.data), 2)
|
||||
self.assertEqual(model_list.data[0].id, "model-1")
|
||||
self.assertEqual(model_list.data[1].id, "model-2")
|
||||
|
||||
|
||||
class TestCompletionRequest(unittest.TestCase):
|
||||
"""Test CompletionRequest protocol model"""
|
||||
|
||||
def test_basic_completion_request(self):
|
||||
"""Test basic completion request"""
|
||||
request = CompletionRequest(model="test-model", prompt="Hello world")
|
||||
self.assertEqual(request.model, "test-model")
|
||||
self.assertEqual(request.prompt, "Hello world")
|
||||
self.assertEqual(request.max_tokens, 16) # default
|
||||
self.assertEqual(request.temperature, 1.0) # default
|
||||
self.assertEqual(request.n, 1) # default
|
||||
self.assertFalse(request.stream) # default
|
||||
self.assertFalse(request.echo) # default
|
||||
|
||||
def test_completion_request_sglang_extensions(self):
|
||||
"""Test completion request with SGLang-specific extensions"""
|
||||
request = CompletionRequest(
|
||||
model="test-model",
|
||||
prompt="Hello",
|
||||
top_k=50,
|
||||
min_p=0.1,
|
||||
repetition_penalty=1.1,
|
||||
regex=r"\d+",
|
||||
json_schema='{"type": "object"}',
|
||||
lora_path="/path/to/lora",
|
||||
)
|
||||
self.assertEqual(request.top_k, 50)
|
||||
self.assertEqual(request.min_p, 0.1)
|
||||
self.assertEqual(request.repetition_penalty, 1.1)
|
||||
self.assertEqual(request.regex, r"\d+")
|
||||
self.assertEqual(request.json_schema, '{"type": "object"}')
|
||||
self.assertEqual(request.lora_path, "/path/to/lora")
|
||||
|
||||
def test_completion_request_validation_errors(self):
|
||||
"""Test completion request validation errors"""
|
||||
with self.assertRaises(ValidationError):
|
||||
CompletionRequest() # missing required fields
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
CompletionRequest(model="test-model") # missing prompt
|
||||
|
||||
|
||||
class TestChatCompletionRequest(unittest.TestCase):
|
||||
"""Test ChatCompletionRequest protocol model"""
|
||||
|
||||
def test_basic_chat_completion_request(self):
|
||||
"""Test basic chat completion request"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
request = ChatCompletionRequest(model="test-model", messages=messages)
|
||||
self.assertEqual(request.model, "test-model")
|
||||
self.assertEqual(len(request.messages), 1)
|
||||
self.assertEqual(request.messages[0].role, "user")
|
||||
self.assertEqual(request.messages[0].content, "Hello")
|
||||
self.assertEqual(request.temperature, 0.7) # default
|
||||
self.assertFalse(request.stream) # default
|
||||
self.assertEqual(request.tool_choice, "none") # default when no tools
|
||||
|
||||
def test_chat_completion_tool_choice_validation(self):
|
||||
"""Test tool choice validation logic"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
# No tools, tool_choice should default to "none"
|
||||
request1 = ChatCompletionRequest(model="test-model", messages=messages)
|
||||
self.assertEqual(request1.tool_choice, "none")
|
||||
|
||||
# With tools, tool_choice should default to "auto"
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "test_func", "description": "Test function"},
|
||||
}
|
||||
]
|
||||
request2 = ChatCompletionRequest(
|
||||
model="test-model", messages=messages, tools=tools
|
||||
)
|
||||
self.assertEqual(request2.tool_choice, "auto")
|
||||
|
||||
def test_chat_completion_sglang_extensions(self):
|
||||
"""Test chat completion with SGLang extensions"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=messages,
|
||||
top_k=40,
|
||||
min_p=0.05,
|
||||
separate_reasoning=False,
|
||||
stream_reasoning=False,
|
||||
chat_template_kwargs={"custom_param": "value"},
|
||||
)
|
||||
self.assertEqual(request.top_k, 40)
|
||||
self.assertEqual(request.min_p, 0.05)
|
||||
self.assertFalse(request.separate_reasoning)
|
||||
self.assertFalse(request.stream_reasoning)
|
||||
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
|
||||
|
||||
|
||||
class TestModelSerialization(unittest.TestCase):
|
||||
"""Test model serialization with hidden states"""
|
||||
|
||||
def test_hidden_states_excluded_when_none(self):
|
||||
"""Test that None hidden_states are excluded with exclude_none=True"""
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
hidden_states=None,
|
||||
)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id="test-id",
|
||||
model="test-model",
|
||||
choices=[choice],
|
||||
usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
|
||||
)
|
||||
|
||||
# Test exclude_none serialization (should exclude None hidden_states)
|
||||
data = response.model_dump(exclude_none=True)
|
||||
self.assertNotIn("hidden_states", data["choices"][0])
|
||||
|
||||
def test_hidden_states_included_when_not_none(self):
|
||||
"""Test that non-None hidden_states are included"""
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
hidden_states=[0.1, 0.2, 0.3],
|
||||
)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id="test-id",
|
||||
model="test-model",
|
||||
choices=[choice],
|
||||
usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
|
||||
)
|
||||
|
||||
# Test exclude_none serialization (should include non-None hidden_states)
|
||||
data = response.model_dump(exclude_none=True)
|
||||
self.assertIn("hidden_states", data["choices"][0])
|
||||
self.assertEqual(data["choices"][0]["hidden_states"], [0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
class TestValidationEdgeCases(unittest.TestCase):
|
||||
"""Test edge cases and validation scenarios"""
|
||||
|
||||
def test_invalid_tool_choice_type(self):
|
||||
"""Test invalid tool choice type"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
with self.assertRaises(ValidationError):
|
||||
ChatCompletionRequest(
|
||||
model="test-model", messages=messages, tool_choice=123
|
||||
)
|
||||
|
||||
def test_negative_token_limits(self):
|
||||
"""Test negative token limits"""
|
||||
with self.assertRaises(ValidationError):
|
||||
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
|
||||
|
||||
def test_model_serialization_roundtrip(self):
|
||||
"""Test that models can be serialized and deserialized"""
|
||||
original_request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
temperature=0.7,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
# Serialize to dict
|
||||
data = original_request.model_dump()
|
||||
|
||||
# Deserialize back
|
||||
restored_request = ChatCompletionRequest(**data)
|
||||
|
||||
self.assertEqual(restored_request.model, original_request.model)
|
||||
self.assertEqual(restored_request.temperature, original_request.temperature)
|
||||
self.assertEqual(restored_request.max_tokens, original_request.max_tokens)
|
||||
self.assertEqual(len(restored_request.messages), len(original_request.messages))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
146
test/srt/openai_server/basic/test_serving_chat.py
Normal file
146
test/srt/openai_server/basic/test_serving_chat.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Unit-tests for OpenAIServingChat — rewritten to use only the std-lib 'unittest'.
|
||||
Run with either:
|
||||
python tests/test_serving_chat_unit.py -v
|
||||
or
|
||||
python -m unittest discover -s tests -p "test_*unit.py" -v
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
|
||||
|
||||
class _MockTokenizerManager:
|
||||
"""Minimal mock that satisfies OpenAIServingChat."""
|
||||
|
||||
def __init__(self):
|
||||
self.model_config = Mock(is_multimodal=False)
|
||||
self.server_args = Mock(
|
||||
enable_cache_report=False,
|
||||
tool_call_parser="hermes",
|
||||
reasoning_parser=None,
|
||||
)
|
||||
self.chat_template_name: Optional[str] = "llama-3"
|
||||
|
||||
# tokenizer stub
|
||||
self.tokenizer = Mock()
|
||||
self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
|
||||
self.tokenizer.decode.return_value = "Test response"
|
||||
self.tokenizer.chat_template = None
|
||||
self.tokenizer.bos_token_id = 1
|
||||
|
||||
# async generator stub for generate_request
|
||||
async def _mock_generate():
|
||||
yield {
|
||||
"text": "Test response",
|
||||
"meta_info": {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"cached_tokens": 0,
|
||||
"finish_reason": {"type": "stop", "matched": None},
|
||||
"output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
|
||||
"output_top_logprobs": None,
|
||||
},
|
||||
"index": 0,
|
||||
}
|
||||
|
||||
self.generate_request = Mock(return_value=_mock_generate())
|
||||
self.create_abort_task = Mock()
|
||||
|
||||
|
||||
class _MockTemplateManager:
|
||||
"""Minimal mock for TemplateManager."""
|
||||
|
||||
def __init__(self):
|
||||
self.chat_template_name: Optional[str] = "llama-3"
|
||||
self.jinja_template_content_format: Optional[str] = None
|
||||
self.completion_template_name: Optional[str] = None
|
||||
|
||||
|
||||
class ServingChatTestCase(unittest.TestCase):
|
||||
# ------------- common fixtures -------------
|
||||
def setUp(self):
|
||||
self.tm = _MockTokenizerManager()
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.chat = OpenAIServingChat(self.tm, self.template_manager)
|
||||
|
||||
# frequently reused requests
|
||||
self.basic_req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=100,
|
||||
stream=False,
|
||||
)
|
||||
self.stream_req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=100,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
self.fastapi_request = Mock(spec=Request)
|
||||
self.fastapi_request.headers = {}
|
||||
|
||||
# ------------- conversion tests -------------
|
||||
def test_convert_to_internal_request_single(self):
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
|
||||
) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
|
||||
conv_ins = Mock()
|
||||
conv_ins.get_prompt.return_value = "Test prompt"
|
||||
conv_ins.image_data = conv_ins.audio_data = None
|
||||
conv_ins.modalities = []
|
||||
conv_ins.stop_str = ["</s>"]
|
||||
conv_mock.return_value = conv_ins
|
||||
|
||||
proc_mock.return_value = (
|
||||
"Test prompt",
|
||||
[1, 2, 3],
|
||||
None,
|
||||
None,
|
||||
[],
|
||||
["</s>"],
|
||||
None,
|
||||
)
|
||||
|
||||
adapted, processed = self.chat._convert_to_internal_request(self.basic_req)
|
||||
self.assertIsInstance(adapted, GenerateReqInput)
|
||||
self.assertFalse(adapted.stream)
|
||||
self.assertEqual(processed, self.basic_req)
|
||||
|
||||
# ------------- sampling-params -------------
|
||||
def test_sampling_param_build(self):
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
temperature=0.8,
|
||||
max_tokens=150,
|
||||
min_tokens=5,
|
||||
top_p=0.9,
|
||||
stop=["</s>"],
|
||||
)
|
||||
with patch.object(
|
||||
self.chat,
|
||||
"_process_messages",
|
||||
return_value=("Prompt", [1], None, None, [], ["</s>"], None),
|
||||
):
|
||||
params = self.chat._build_sampling_params(req, ["</s>"], None)
|
||||
self.assertEqual(params["temperature"], 0.8)
|
||||
self.assertEqual(params["max_new_tokens"], 150)
|
||||
self.assertEqual(params["min_new_tokens"], 5)
|
||||
self.assertEqual(params["stop"], ["</s>"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
100
test/srt/openai_server/basic/test_serving_completions.py
Normal file
100
test/srt/openai_server/basic/test_serving_completions.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Unit-tests for the refactored completions-serving handler (no pytest).
|
||||
Run with:
|
||||
python -m unittest tests.test_serving_completions_unit -v
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
|
||||
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
|
||||
|
||||
class _MockTemplateManager:
|
||||
"""Minimal mock for TemplateManager."""
|
||||
|
||||
def __init__(self):
|
||||
self.chat_template_name: Optional[str] = None
|
||||
self.jinja_template_content_format: Optional[str] = None
|
||||
self.completion_template_name: Optional[str] = (
|
||||
None # Set to None to avoid template processing
|
||||
)
|
||||
|
||||
|
||||
class ServingCompletionTestCase(unittest.TestCase):
|
||||
"""Bundle all prompt/echo tests in one TestCase."""
|
||||
|
||||
# ---------- shared test fixtures ----------
|
||||
def setUp(self):
|
||||
# build the mock TokenizerManager once for every test
|
||||
tm = Mock(spec=TokenizerManager)
|
||||
|
||||
tm.tokenizer = Mock()
|
||||
tm.tokenizer.encode.return_value = [1, 2, 3, 4]
|
||||
tm.tokenizer.decode.return_value = "decoded text"
|
||||
tm.tokenizer.bos_token_id = 1
|
||||
|
||||
tm.model_config = Mock(is_multimodal=False)
|
||||
tm.server_args = Mock(enable_cache_report=False)
|
||||
|
||||
tm.generate_request = AsyncMock()
|
||||
tm.create_abort_task = Mock()
|
||||
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.sc = OpenAIServingCompletion(tm, self.template_manager)
|
||||
|
||||
# ---------- prompt-handling ----------
|
||||
def test_single_string_prompt(self):
|
||||
req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.text, "Hello world")
|
||||
|
||||
def test_single_token_ids_prompt(self):
|
||||
req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
|
||||
|
||||
# ---------- echo-handling ----------
|
||||
def test_echo_with_string_prompt_streaming(self):
|
||||
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "Hello")
|
||||
|
||||
def test_echo_with_list_of_strings_streaming(self):
|
||||
req = CompletionRequest(
|
||||
model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1
|
||||
)
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "A")
|
||||
self.assertEqual(self.sc._get_echo_text(req, 1), "B")
|
||||
|
||||
def test_echo_with_token_ids_streaming(self):
|
||||
req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True)
|
||||
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt"
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt")
|
||||
|
||||
def test_echo_with_multiple_token_ids_streaming(self):
|
||||
req = CompletionRequest(
|
||||
model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1
|
||||
)
|
||||
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded")
|
||||
|
||||
def test_prepare_echo_prompts_non_streaming(self):
|
||||
# single string
|
||||
req = CompletionRequest(model="x", prompt="Hi", echo=True)
|
||||
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])
|
||||
|
||||
# list of strings
|
||||
req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
|
||||
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])
|
||||
|
||||
# token IDs
|
||||
req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
|
||||
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
||||
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
145
test/srt/openai_server/basic/test_serving_embedding.py
Normal file
145
test/srt/openai_server/basic/test_serving_embedding.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import uuid
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
MultimodalEmbeddingInput,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
||||
|
||||
|
||||
# Mock TokenizerManager for embedding tests
|
||||
class _MockTokenizerManager:
|
||||
def __init__(self):
|
||||
self.model_config = Mock()
|
||||
self.model_config.is_multimodal = False
|
||||
self.server_args = Mock()
|
||||
self.server_args.enable_cache_report = False
|
||||
self.model_path = "test-model"
|
||||
|
||||
# Mock tokenizer
|
||||
self.tokenizer = Mock()
|
||||
self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
|
||||
self.tokenizer.decode = Mock(return_value="Test embedding input")
|
||||
self.tokenizer.chat_template = None
|
||||
self.tokenizer.bos_token_id = 1
|
||||
|
||||
# Mock generate_request method for embeddings
|
||||
async def mock_generate_embedding():
|
||||
yield {
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20, # 100-dim embedding
|
||||
"meta_info": {
|
||||
"id": f"embd-{uuid.uuid4()}",
|
||||
"prompt_tokens": 5,
|
||||
},
|
||||
}
|
||||
|
||||
self.generate_request = Mock(return_value=mock_generate_embedding())
|
||||
|
||||
|
||||
# Mock TemplateManager for embedding tests
|
||||
class _MockTemplateManager:
|
||||
def __init__(self):
|
||||
self.chat_template_name = None # None for embeddings usually
|
||||
self.jinja_template_content_format = None
|
||||
self.completion_template_name = None
|
||||
|
||||
|
||||
class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.tokenizer_manager = _MockTokenizerManager()
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.serving_embedding = OpenAIServingEmbedding(
|
||||
self.tokenizer_manager, self.template_manager
|
||||
)
|
||||
|
||||
self.request = Mock(spec=Request)
|
||||
self.request.headers = {}
|
||||
|
||||
self.basic_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input="Hello, how are you?",
|
||||
encoding_format="float",
|
||||
)
|
||||
self.list_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=["Hello, how are you?", "I am fine, thank you!"],
|
||||
encoding_format="float",
|
||||
)
|
||||
self.multimodal_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=[
|
||||
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
|
||||
MultimodalEmbeddingInput(text="World", image=None),
|
||||
],
|
||||
encoding_format="float",
|
||||
)
|
||||
self.token_ids_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=[1, 2, 3, 4, 5],
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
def test_convert_single_string_request(self):
|
||||
"""Test converting single string request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.basic_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(adapted_request.text, "Hello, how are you?")
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.basic_req)
|
||||
|
||||
def test_convert_list_string_request(self):
|
||||
"""Test converting list of strings request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.list_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(
|
||||
adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
|
||||
)
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.list_req)
|
||||
|
||||
def test_convert_token_ids_request(self):
|
||||
"""Test converting token IDs request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.token_ids_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.token_ids_req)
|
||||
|
||||
def test_convert_multimodal_request(self):
|
||||
"""Test converting multimodal request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.multimodal_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
# Should extract text and images separately
|
||||
self.assertEqual(len(adapted_request.text), 2)
|
||||
self.assertIn("Hello", adapted_request.text)
|
||||
self.assertIn("World", adapted_request.text)
|
||||
self.assertEqual(adapted_request.image_data[0], "base64_image_data")
|
||||
self.assertIsNone(adapted_request.image_data[1])
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
0
test/srt/openai_server/features/__init__.py
Normal file
0
test/srt/openai_server/features/__init__.py
Normal file
211
test/srt/openai_server/features/test_cache_report.py
Normal file
211
test/srt/openai_server/features/test_cache_report.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestCacheReport(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.min_cached = 5
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=300,
|
||||
other_args=[
|
||||
"--chunked-prefill-size=40",
|
||||
"--enable-cache-report",
|
||||
],
|
||||
)
|
||||
cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
|
||||
cls.aclient = openai.AsyncClient(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
|
||||
|
||||
usage = cls.run_openai(cls, "1").usage
|
||||
# we can assume that our request is of size 1, plus the total template size
|
||||
# ideally we would like to know the begin size / end size of the template to be more precise
|
||||
total_template_size = usage.prompt_tokens - 1
|
||||
print(f"template size: {total_template_size}")
|
||||
usage2 = cls.run_openai(cls, "2").usage
|
||||
assert usage2.prompt_tokens_details.cached_tokens <= total_template_size
|
||||
cls.min_cached = max(
|
||||
usage2.prompt_tokens_details.cached_tokens,
|
||||
total_template_size - usage2.prompt_tokens_details.cached_tokens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
# we use an uncommon start to minimise the chance that the cache is hit by chance
|
||||
json={
|
||||
"text": "_ The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": 128,
|
||||
"n": n,
|
||||
"stop_token_ids": [119690],
|
||||
},
|
||||
"stream": False,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
def run_openai(self, message):
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
# {"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": message},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=100,
|
||||
)
|
||||
return response
|
||||
|
||||
async def run_openai_async(self, message):
|
||||
response = await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "user", "content": message},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=100,
|
||||
)
|
||||
return response
|
||||
|
||||
def cache_report_openai(self, message):
|
||||
response = self.run_openai(message)
|
||||
print(
|
||||
f"openai first request cached_tokens: {int(response.usage.prompt_tokens_details.cached_tokens)}"
|
||||
)
|
||||
first_cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
||||
# assert int(response.usage.cached_tokens) == 0
|
||||
assert first_cached_tokens <= self.min_cached
|
||||
response = self.run_openai(message)
|
||||
cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
||||
print(f"openai second request cached_tokens: {cached_tokens}")
|
||||
assert cached_tokens > 0
|
||||
assert cached_tokens == int(response.usage.prompt_tokens) - 1
|
||||
return first_cached_tokens
|
||||
|
||||
async def cache_report_openai_async(self, message):
|
||||
response = await self.run_openai_async(message)
|
||||
cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
||||
prompt_tokens = int(response.usage.prompt_tokens)
|
||||
return cached_tokens, prompt_tokens
|
||||
|
||||
def test_generate(self):
|
||||
print("=" * 100)
|
||||
response = self.run_decode()
|
||||
# print(response.json())
|
||||
cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
|
||||
print(f"sglang first request cached_tokens: {cached_tokens}")
|
||||
print(
|
||||
f"sglang first request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
|
||||
)
|
||||
# can't assure to be 0: depends on the initialisation request / if a template is used with the model
|
||||
assert cached_tokens < self.min_cached
|
||||
response = self.run_decode()
|
||||
cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
|
||||
print(f"sglang second request cached_tokens: {cached_tokens}")
|
||||
print(
|
||||
f"sglang second request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
|
||||
)
|
||||
assert cached_tokens == int(response.json()["meta_info"]["prompt_tokens"]) - 1
|
||||
|
||||
def test_cache_split_prefill_openai(self):
|
||||
print("=" * 100)
|
||||
self.cache_report_openai(
|
||||
"€ This is a very long and unique text that should not be already cached, the twist is"
|
||||
" that it should be longer than the chunked-prefill-size, so it should be split among"
|
||||
" several prefill requests. Still, it shouldn't be cached"
|
||||
)
|
||||
|
||||
def test_cache_report_openai(self):
|
||||
print("=" * 100)
|
||||
# warm up the cache, for the template
|
||||
self.run_openai("Introduce the capital of France.")
|
||||
|
||||
first_cached_tokens_1 = self.run_openai(
|
||||
"How many sparrow do you need to lift a coconut?"
|
||||
).usage.prompt_tokens_details.cached_tokens
|
||||
|
||||
usage_2 = self.run_openai("* sing something about cats").usage
|
||||
first_cached_tokens_2 = usage_2.prompt_tokens_details.cached_tokens
|
||||
# first request may not have 0 cached tokens, but if they only have the template in common they
|
||||
# should be the same once the cache is warmed up
|
||||
assert first_cached_tokens_1 == first_cached_tokens_2
|
||||
|
||||
resp = self.run_openai("* sing something about cats and dogs")
|
||||
print(resp.usage)
|
||||
|
||||
resp = self.run_openai("* sing something about cats, please")
|
||||
print(resp.usage)
|
||||
assert (
|
||||
resp.usage.prompt_tokens_details.cached_tokens
|
||||
>= usage_2.prompt_tokens - self.min_cached
|
||||
)
|
||||
|
||||
def test_cache_report_openai_async(self):
|
||||
print("=" * 100)
|
||||
|
||||
async def run_test():
|
||||
task0 = asyncio.create_task(
|
||||
self.cache_report_openai_async(
|
||||
"first request, to start the inference and let the next two request be started in the same batch"
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(0.05) # to force the first request to be started first
|
||||
task1 = asyncio.create_task(
|
||||
self.cache_report_openai_async(
|
||||
"> can the same batch parallel request use the cache?"
|
||||
)
|
||||
)
|
||||
task2 = asyncio.create_task(
|
||||
self.cache_report_openai_async(
|
||||
"> can the same batch parallel request use the cache?"
|
||||
)
|
||||
)
|
||||
result0, result1, result2 = await asyncio.gather(task0, task1, task2)
|
||||
|
||||
cached_tokens0, prompt_tokens0 = result0
|
||||
cached_tokens1, prompt_tokens1 = result1
|
||||
cached_tokens2, prompt_tokens2 = result2
|
||||
|
||||
print(
|
||||
f"Async request 0 - Cached tokens: {cached_tokens0}, Prompt tokens: {prompt_tokens0}"
|
||||
)
|
||||
print(
|
||||
f"Async request 1 - Cached tokens: {cached_tokens1}, Prompt tokens: {prompt_tokens1}"
|
||||
)
|
||||
print(
|
||||
f"Async request 2 - Cached tokens: {cached_tokens2}, Prompt tokens: {prompt_tokens2}"
|
||||
)
|
||||
|
||||
# Assert that no requests used the cache (becausefirst is alone, and the next two are in the same batch)
|
||||
# If a new optimisation limiting starting request with same prefix at the same time was added
|
||||
# to maximise the cache hit, this would not be true
|
||||
assert cached_tokens1 == cached_tokens2 == cached_tokens0
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
188
test/srt/openai_server/features/test_enable_thinking.py
Normal file
188
test/srt/openai_server/features/test_enable_thinking.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_with_reasoning
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_without_reasoning
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_with_reasoning
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_without_reasoning
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestEnableThinking(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-1234"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=[
|
||||
"--reasoning-parser",
|
||||
"qwen3",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_chat_completion_with_reasoning(self):
|
||||
# Test non-streaming with "enable_thinking": True, reasoning_content should not be empty
|
||||
client = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"chat_template_kwargs": {"enable_thinking": True},
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
|
||||
data = client.json()
|
||||
|
||||
self.assertIn("choices", data)
|
||||
self.assertTrue(len(data["choices"]) > 0)
|
||||
self.assertIn("message", data["choices"][0])
|
||||
self.assertIn("reasoning_content", data["choices"][0]["message"])
|
||||
self.assertIsNotNone(data["choices"][0]["message"]["reasoning_content"])
|
||||
|
||||
def test_chat_completion_without_reasoning(self):
|
||||
# Test non-streaming with "enable_thinking": False, reasoning_content should be empty
|
||||
client = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
|
||||
data = client.json()
|
||||
|
||||
self.assertIn("choices", data)
|
||||
self.assertTrue(len(data["choices"]) > 0)
|
||||
self.assertIn("message", data["choices"][0])
|
||||
|
||||
if "reasoning_content" in data["choices"][0]["message"]:
|
||||
self.assertIsNone(data["choices"][0]["message"]["reasoning_content"])
|
||||
|
||||
def test_stream_chat_completion_with_reasoning(self):
|
||||
# Test streaming with "enable_thinking": True, reasoning_content should not be empty
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"stream": True,
|
||||
"chat_template_kwargs": {"enable_thinking": True},
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
|
||||
|
||||
has_reasoning = False
|
||||
has_content = False
|
||||
|
||||
print("\n=== Stream With Reasoning ===")
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.decode("utf-8")
|
||||
if line.startswith("data:") and not line.startswith("data: [DONE]"):
|
||||
data = json.loads(line[6:])
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
delta = data["choices"][0].get("delta", {})
|
||||
|
||||
if "reasoning_content" in delta and delta["reasoning_content"]:
|
||||
has_reasoning = True
|
||||
|
||||
if "content" in delta and delta["content"]:
|
||||
has_content = True
|
||||
|
||||
self.assertTrue(
|
||||
has_reasoning,
|
||||
"The reasoning content is not included in the stream response",
|
||||
)
|
||||
self.assertTrue(
|
||||
has_content, "The stream response does not contain normal content"
|
||||
)
|
||||
|
||||
def test_stream_chat_completion_without_reasoning(self):
|
||||
# Test streaming with "enable_thinking": False, reasoning_content should be empty
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"stream": True,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
|
||||
|
||||
has_reasoning = False
|
||||
has_content = False
|
||||
|
||||
print("\n=== Stream Without Reasoning ===")
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.decode("utf-8")
|
||||
if line.startswith("data:") and not line.startswith("data: [DONE]"):
|
||||
data = json.loads(line[6:])
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
delta = data["choices"][0].get("delta", {})
|
||||
|
||||
if "reasoning_content" in delta and delta["reasoning_content"]:
|
||||
has_reasoning = True
|
||||
|
||||
if "content" in delta and delta["content"]:
|
||||
has_content = True
|
||||
|
||||
self.assertFalse(
|
||||
has_reasoning,
|
||||
"The reasoning content should not be included in the stream response",
|
||||
)
|
||||
self.assertTrue(
|
||||
has_content, "The stream response does not contain normal content"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
153
test/srt/openai_server/features/test_json_constrained.py
Normal file
153
test/srt/openai_server/features/test_json_constrained.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
|
||||
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
|
||||
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def setup_class(cls, backend: str):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.json_schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "pattern": "^[\\w]+$"},
|
||||
"population": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "population"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
)
|
||||
|
||||
other_args = [
|
||||
"--max-running-requests",
|
||||
"10",
|
||||
"--grammar-backend",
|
||||
backend,
|
||||
]
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_args,
|
||||
)
|
||||
|
||||
|
||||
class TestJSONConstrainedOutlinesBackend(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, backend="outlines")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": 128,
|
||||
"n": n,
|
||||
"stop_token_ids": [119690],
|
||||
"json_schema": json_schema,
|
||||
},
|
||||
"stream": False,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
ret = response.json()
|
||||
print(json.dumps(ret))
|
||||
print("=" * 100)
|
||||
|
||||
if not json_schema or json_schema == "INVALID":
|
||||
return
|
||||
|
||||
# Make sure the json output is valid
|
||||
try:
|
||||
js_obj = json.loads(ret["text"])
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
raise
|
||||
|
||||
self.assertIsInstance(js_obj["name"], str)
|
||||
self.assertIsInstance(js_obj["population"], int)
|
||||
|
||||
def test_json_generate(self):
|
||||
self.run_decode(json_schema=self.json_schema)
|
||||
|
||||
def test_json_invalid(self):
|
||||
self.run_decode(json_schema="INVALID")
|
||||
|
||||
def test_json_openai(self):
|
||||
client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Introduce the capital of France. Return in a JSON format.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {"name": "foo", "schema": json.loads(self.json_schema)},
|
||||
},
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
|
||||
try:
|
||||
js_obj = json.loads(text)
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
print("JSONDecodeError", text)
|
||||
raise
|
||||
|
||||
self.assertIsInstance(js_obj["name"], str)
|
||||
self.assertIsInstance(js_obj["population"], int)
|
||||
|
||||
def test_mix_json_and_other(self):
|
||||
json_schemas = [None, None, self.json_schema, self.json_schema] * 10
|
||||
|
||||
with ThreadPoolExecutor(len(json_schemas)) as executor:
|
||||
list(executor.map(self.run_decode, json_schemas))
|
||||
|
||||
|
||||
class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, backend="xgrammar")
|
||||
|
||||
|
||||
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, backend="llguidance")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
137
test/srt/openai_server/features/test_json_mode.py
Normal file
137
test/srt/openai_server/features/test_json_mode.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeOutlines.test_json_mode_response
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeOutlines.test_json_mode_with_streaming
|
||||
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeXGrammar.test_json_mode_response
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeXGrammar.test_json_mode_with_streaming
|
||||
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeLLGuidance.test_json_mode_response
|
||||
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeLLGuidance.test_json_mode_with_streaming
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def setup_class(cls, backend):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
|
||||
other_args = [
|
||||
"--max-running-requests",
|
||||
"10",
|
||||
"--grammar-backend",
|
||||
backend,
|
||||
]
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_args,
|
||||
)
|
||||
cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
|
||||
|
||||
|
||||
class TestJSONModeOutlines(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, "outlines")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_json_mode_response(self):
|
||||
"""Test that response_format json_object (also known as "json mode") produces valid JSON, even without a system prompt that mentions JSON."""
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
# We are deliberately omitting "That produces JSON" or similar phrases from the assistant prompt so that we don't have misleading test results
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant that gives a short answer.",
|
||||
},
|
||||
{"role": "user", "content": "What is the capital of Bulgaria?"},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
|
||||
print(f"Response ({len(text)} characters): {text}")
|
||||
|
||||
# Verify the response is valid JSON
|
||||
try:
|
||||
js_obj = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
self.fail(f"Response is not valid JSON. Error: {e}. Response: {text}")
|
||||
|
||||
# Verify it's actually an object (dict)
|
||||
self.assertIsInstance(js_obj, dict, f"Response is not a JSON object: {text}")
|
||||
|
||||
def test_json_mode_with_streaming(self):
|
||||
"""Test that streaming with json_object response (also known as "json mode") format works correctly, even without a system prompt that mentions JSON."""
|
||||
stream = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
# We are deliberately omitting "That produces JSON" or similar phrases from the assistant prompt so that we don't have misleading test results
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant that gives a short answer.",
|
||||
},
|
||||
{"role": "user", "content": "What is the capital of Bulgaria?"},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
response_format={"type": "json_object"},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Collect all chunks
|
||||
chunks = []
|
||||
for chunk in stream:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
chunks.append(chunk.choices[0].delta.content)
|
||||
full_response = "".join(chunks)
|
||||
|
||||
print(
|
||||
f"Concatenated Response ({len(full_response)} characters): {full_response}"
|
||||
)
|
||||
|
||||
# Verify the combined response is valid JSON
|
||||
try:
|
||||
js_obj = json.loads(full_response)
|
||||
except json.JSONDecodeError as e:
|
||||
self.fail(
|
||||
f"Streamed response is not valid JSON. Error: {e}. Response: {full_response}"
|
||||
)
|
||||
|
||||
self.assertIsInstance(js_obj, dict)
|
||||
|
||||
|
||||
class TestJSONModeXGrammar(TestJSONModeOutlines):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, backend="xgrammar")
|
||||
|
||||
|
||||
class TestJSONModeLLGuidance(TestJSONModeOutlines):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, backend="llguidance")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
98
test/srt/openai_server/features/test_openai_server_ebnf.py
Normal file
98
test/srt/openai_server/features/test_openai_server_ebnf.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import re
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# EBNF Test Class: TestOpenAIServerEBNF
|
||||
# Launches the server with xgrammar, has only EBNF tests
|
||||
# -------------------------------------------------------------------------
|
||||
class TestOpenAIServerEBNF(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
# passing xgrammar specifically
|
||||
other_args = ["--grammar-backend", "xgrammar"]
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=other_args,
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_ebnf(self):
|
||||
"""
|
||||
Ensure we can pass `ebnf` to the local openai server
|
||||
and that it enforces the grammar.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
ebnf_grammar = r"""
|
||||
root ::= "Hello" | "Hi" | "Hey"
|
||||
"""
|
||||
pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$")
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful EBNF test bot."},
|
||||
{"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
extra_body={"ebnf": ebnf_grammar},
|
||||
)
|
||||
text = response.choices[0].message.content.strip()
|
||||
self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
|
||||
self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")
|
||||
|
||||
def test_ebnf_strict_json(self):
|
||||
"""
|
||||
A stricter EBNF that produces exactly {"name":"Alice"} format
|
||||
with no trailing punctuation or extra fields.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
ebnf_grammar = r"""
|
||||
root ::= "{" pair "}"
|
||||
pair ::= "\"name\"" ":" string
|
||||
string ::= "\"" [A-Za-z]+ "\""
|
||||
"""
|
||||
pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$')
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "EBNF mini-JSON generator."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Generate single key JSON with only letters.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=64,
|
||||
extra_body={"ebnf": ebnf_grammar},
|
||||
)
|
||||
text = response.choices[0].message.content.strip()
|
||||
self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
|
||||
self.assertRegex(
|
||||
text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
|
||||
)
|
||||
@@ -0,0 +1,356 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import unittest
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import torch
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class BaseTestOpenAIServerWithHiddenStates(ABC):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.return_hidden_states = [False, True]
|
||||
cls.use_list_input = [True, False]
|
||||
cls.parallel_sample_nums = [1, 2]
|
||||
|
||||
def test_completion(self):
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for use_list_input in self.use_list_input:
|
||||
for parallel_sample_num in self.parallel_sample_nums:
|
||||
self.run_completion(
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
return_hidden_states,
|
||||
)
|
||||
|
||||
def test_completion_stream(self):
|
||||
# parallel sampling and list input are not supported in streaming mode
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for use_list_input in self.use_list_input:
|
||||
for parallel_sample_num in self.parallel_sample_nums:
|
||||
self.run_completion_stream(
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
return_hidden_states,
|
||||
)
|
||||
|
||||
def test_chat_completion(self):
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for (
|
||||
parallel_sample_num
|
||||
) in (
|
||||
self.parallel_sample_nums
|
||||
): # parallel sample num 2 breaks in the adapter with a 400 for EAGLE
|
||||
self.run_chat_completion(parallel_sample_num, return_hidden_states)
|
||||
|
||||
def test_chat_completion_stream(self):
|
||||
for return_hidden_states in self.return_hidden_states:
|
||||
for (
|
||||
parallel_sample_num
|
||||
) in (
|
||||
self.parallel_sample_nums
|
||||
): # parallel sample num > 1 breaks in the adapter with a 400 for EAGLE
|
||||
self.run_chat_completion_stream(
|
||||
parallel_sample_num, return_hidden_states
|
||||
)
|
||||
|
||||
def run_completion(
|
||||
self,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
return_hidden_states,
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
prompt_input = prompt
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
response = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
n=parallel_sample_num,
|
||||
extra_body=dict(return_hidden_states=return_hidden_states),
|
||||
)
|
||||
|
||||
for choice in response.choices:
|
||||
assert hasattr(choice, "hidden_states") == return_hidden_states
|
||||
if return_hidden_states:
|
||||
assert choice.hidden_states is not None, "hidden_states was None"
|
||||
|
||||
def run_completion_stream(
|
||||
self,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
return_hidden_states,
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
prompt_input = prompt
|
||||
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
num_prompt_tokens *= 2
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
generator = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
extra_body=dict(return_hidden_states=return_hidden_states),
|
||||
)
|
||||
|
||||
hidden_states_list = []
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
for choice in response.choices:
|
||||
if hasattr(choice, "hidden_states"):
|
||||
assert return_hidden_states
|
||||
assert choice.hidden_states is not None
|
||||
hidden_states_list.append(choice.hidden_states)
|
||||
|
||||
if return_hidden_states:
|
||||
assert (
|
||||
len(hidden_states_list) == parallel_sample_num * num_choices
|
||||
), f"Expected {parallel_sample_num * num_choices} hidden states, got {len(hidden_states_list)}"
|
||||
else:
|
||||
assert (
|
||||
hidden_states_list == []
|
||||
), "hidden_states were returned and should not have been"
|
||||
|
||||
def run_chat_completion(self, parallel_sample_num, return_hidden_states):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France? Answer in a few words.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
n=parallel_sample_num,
|
||||
extra_body=dict(return_hidden_states=return_hidden_states),
|
||||
)
|
||||
|
||||
for choice in response.choices:
|
||||
assert hasattr(choice, "hidden_states") == return_hidden_states
|
||||
if return_hidden_states:
|
||||
assert choice.hidden_states is not None, "hidden_states was None"
|
||||
|
||||
def run_chat_completion_stream(
|
||||
self, parallel_sample_num=1, return_hidden_states=False
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
generator = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
],
|
||||
temperature=0,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
extra_body=dict(return_hidden_states=return_hidden_states),
|
||||
)
|
||||
|
||||
is_firsts = {}
|
||||
hidden_states_list = []
|
||||
|
||||
for response in generator:
|
||||
for choice in response.choices:
|
||||
if hasattr(choice.delta, "hidden_states"):
|
||||
assert return_hidden_states
|
||||
assert choice.delta.hidden_states is not None
|
||||
hidden_states_list.append(choice.delta.hidden_states)
|
||||
|
||||
if return_hidden_states:
|
||||
assert (
|
||||
len(hidden_states_list) == parallel_sample_num
|
||||
), f"Expected {parallel_sample_num} hidden states, got {len(hidden_states_list)}"
|
||||
else:
|
||||
assert (
|
||||
hidden_states_list == []
|
||||
), "hidden_states were returned and should not have been"
|
||||
|
||||
|
||||
class TestOpenAIServerWithHiddenStatesEnabled(
|
||||
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
|
||||
):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=["--enable-return-hidden-states"],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
cls.return_hidden_states = [False, True]
|
||||
cls.use_list_input = [True, False]
|
||||
cls.parallel_sample_nums = [1, 2]
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
class TestOpenAIServerWithHiddenStatesEnabledAndCUDAGraphDisabled(
|
||||
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
|
||||
):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=["--enable-return-hidden-states", "--disable-cuda-graph"],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
cls.return_hidden_states = [False, True]
|
||||
cls.use_list_input = [True, False]
|
||||
cls.parallel_sample_nums = [1]
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
class TestOpenAIServerWithEAGLEAndHiddenStatesEnabled(
|
||||
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
|
||||
):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.speculative_draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
|
||||
cls.speculative_algorithm = "EAGLE"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--speculative-algorithm",
|
||||
"EAGLE",
|
||||
"--speculative-draft-model-path",
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
"--speculative-num-steps",
|
||||
5,
|
||||
"--speculative-eagle-topk",
|
||||
8,
|
||||
"--speculative-num-draft-tokens",
|
||||
64,
|
||||
"--mem-fraction-static",
|
||||
0.7,
|
||||
"--chunked-prefill-size",
|
||||
128,
|
||||
"--max-running-requests",
|
||||
8,
|
||||
"--enable-return-hidden-states",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
|
||||
cls.return_hidden_states = [False, True]
|
||||
cls.use_list_input = [True, False]
|
||||
cls.parallel_sample_nums = [1]
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
class TestOpenAIServerWithEAGLE3AndHiddenStatesEnabled(
|
||||
CustomTestCase, BaseTestOpenAIServerWithHiddenStates
|
||||
):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.speculative_algorithm = "EAGLE3"
|
||||
cls.speculative_draft_model = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--speculative-algorithm",
|
||||
cls.speculative_algorithm,
|
||||
"--speculative-draft-model-path",
|
||||
cls.speculative_draft_model,
|
||||
"--speculative-num-steps",
|
||||
5,
|
||||
"--speculative-eagle-topk",
|
||||
16,
|
||||
"--speculative-num-draft-tokens",
|
||||
64,
|
||||
"--mem-fraction-static",
|
||||
0.7,
|
||||
"--chunked-prefill-size",
|
||||
128,
|
||||
"--max-running-requests",
|
||||
8,
|
||||
"--dtype",
|
||||
"float16",
|
||||
"--enable-return-hidden-states",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(cls.model)
|
||||
cls.return_hidden_states = [False, True]
|
||||
cls.use_list_input = [True, False]
|
||||
cls.parallel_sample_nums = [1]
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
343
test/srt/openai_server/features/test_reasoning_content.py
Normal file
343
test/srt/openai_server/features/test_reasoning_content.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_false
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true_stream_reasoning_false
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_false
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_true
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentStartup.test_nonstreaming
|
||||
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentStartup.test_streaming
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_REASONING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestReasoningContentAPI(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-1234"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=[
|
||||
"--reasoning-parser",
|
||||
"deepseek-r1",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_streaming_separate_reasoning_false(self):
|
||||
# Test streaming with separate_reasoning=False, reasoning_content should be empty
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"stream": True,
|
||||
"extra_body": {"separate_reasoning": False},
|
||||
}
|
||||
response = client.chat.completions.create(**payload)
|
||||
|
||||
reasoning_content = ""
|
||||
content = ""
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content:
|
||||
content += chunk.choices[0].delta.content
|
||||
elif chunk.choices[0].delta.reasoning_content:
|
||||
reasoning_content += chunk.choices[0].delta.reasoning_content
|
||||
|
||||
assert len(reasoning_content) == 0
|
||||
assert len(content) > 0
|
||||
|
||||
def test_streaming_separate_reasoning_true(self):
|
||||
# Test streaming with separate_reasoning=True, reasoning_content should not be empty
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"stream": True,
|
||||
"extra_body": {"separate_reasoning": True},
|
||||
}
|
||||
response = client.chat.completions.create(**payload)
|
||||
|
||||
reasoning_content = ""
|
||||
content = ""
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content:
|
||||
content += chunk.choices[0].delta.content
|
||||
elif chunk.choices[0].delta.reasoning_content:
|
||||
reasoning_content += chunk.choices[0].delta.reasoning_content
|
||||
|
||||
assert len(reasoning_content) > 0
|
||||
assert len(content) > 0
|
||||
|
||||
def test_streaming_separate_reasoning_true_stream_reasoning_false(self):
|
||||
# Test streaming with separate_reasoning=True, reasoning_content should not be empty
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"stream": True,
|
||||
"extra_body": {"separate_reasoning": True, "stream_reasoning": False},
|
||||
}
|
||||
response = client.chat.completions.create(**payload)
|
||||
|
||||
reasoning_content = ""
|
||||
content = ""
|
||||
first_chunk = False
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.reasoning_content:
|
||||
reasoning_content = chunk.choices[0].delta.reasoning_content
|
||||
first_chunk = True
|
||||
if chunk.choices[0].delta.content:
|
||||
content += chunk.choices[0].delta.content
|
||||
if not first_chunk:
|
||||
reasoning_content = chunk.choices[0].delta.reasoning_content
|
||||
first_chunk = True
|
||||
if not first_chunk:
|
||||
assert (
|
||||
not chunk.choices[0].delta.reasoning_content
|
||||
or len(chunk.choices[0].delta.reasoning_content) == 0
|
||||
)
|
||||
assert len(reasoning_content) > 0
|
||||
assert len(content) > 0
|
||||
|
||||
def test_nonstreaming_separate_reasoning_false(self):
|
||||
# Test non-streaming with separate_reasoning=False, reasoning_content should be empty
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"extra_body": {"separate_reasoning": False},
|
||||
}
|
||||
response = client.chat.completions.create(**payload)
|
||||
|
||||
assert (
|
||||
not response.choices[0].message.reasoning_content
|
||||
or len(response.choices[0].message.reasoning_content) == 0
|
||||
)
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
|
||||
def test_nonstreaming_separate_reasoning_true(self):
|
||||
# Test non-streaming with separate_reasoning=True, reasoning_content should not be empty
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"extra_body": {"separate_reasoning": True},
|
||||
}
|
||||
response = client.chat.completions.create(**payload)
|
||||
|
||||
assert len(response.choices[0].message.reasoning_content) > 0
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
|
||||
|
||||
class TestReasoningContentWithoutParser(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-1234"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=[], # No reasoning parser
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_streaming_separate_reasoning_false(self):
|
||||
# Test streaming with separate_reasoning=False, reasoning_content should be empty
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"stream": True,
|
||||
"extra_body": {"separate_reasoning": False},
|
||||
}
|
||||
response = client.chat.completions.create(**payload)
|
||||
|
||||
reasoning_content = ""
|
||||
content = ""
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content:
|
||||
content += chunk.choices[0].delta.content
|
||||
elif chunk.choices[0].delta.reasoning_content:
|
||||
reasoning_content += chunk.choices[0].delta.reasoning_content
|
||||
|
||||
assert len(reasoning_content) == 0
|
||||
assert len(content) > 0
|
||||
|
||||
def test_streaming_separate_reasoning_true(self):
|
||||
# Test streaming with separate_reasoning=True, reasoning_content should not be empty
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"stream": True,
|
||||
"extra_body": {"separate_reasoning": True},
|
||||
}
|
||||
response = client.chat.completions.create(**payload)
|
||||
|
||||
reasoning_content = ""
|
||||
content = ""
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content:
|
||||
content += chunk.choices[0].delta.content
|
||||
elif chunk.choices[0].delta.reasoning_content:
|
||||
reasoning_content += chunk.choices[0].delta.reasoning_content
|
||||
|
||||
assert len(reasoning_content) == 0
|
||||
assert len(content) > 0
|
||||
|
||||
def test_streaming_separate_reasoning_true_stream_reasoning_false(self):
|
||||
# Test streaming with separate_reasoning=True, reasoning_content should not be empty
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"stream": True,
|
||||
"extra_body": {"separate_reasoning": True, "stream_reasoning": False},
|
||||
}
|
||||
response = client.chat.completions.create(**payload)
|
||||
|
||||
reasoning_content = ""
|
||||
content = ""
|
||||
first_chunk = False
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.reasoning_content:
|
||||
reasoning_content = chunk.choices[0].delta.reasoning_content
|
||||
first_chunk = True
|
||||
if chunk.choices[0].delta.content:
|
||||
content += chunk.choices[0].delta.content
|
||||
if not first_chunk:
|
||||
reasoning_content = chunk.choices[0].delta.reasoning_content
|
||||
first_chunk = True
|
||||
if not first_chunk:
|
||||
assert (
|
||||
not chunk.choices[0].delta.reasoning_content
|
||||
or len(chunk.choices[0].delta.reasoning_content) == 0
|
||||
)
|
||||
assert not reasoning_content or len(reasoning_content) == 0
|
||||
assert len(content) > 0
|
||||
|
||||
def test_nonstreaming_separate_reasoning_false(self):
|
||||
# Test non-streaming with separate_reasoning=False, reasoning_content should be empty
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"extra_body": {"separate_reasoning": False},
|
||||
}
|
||||
response = client.chat.completions.create(**payload)
|
||||
|
||||
assert (
|
||||
not response.choices[0].message.reasoning_content
|
||||
or len(response.choices[0].message.reasoning_content) == 0
|
||||
)
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
|
||||
def test_nonstreaming_separate_reasoning_true(self):
|
||||
# Test non-streaming with separate_reasoning=True, reasoning_content should not be empty
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"extra_body": {"separate_reasoning": True},
|
||||
}
|
||||
response = client.chat.completions.create(**payload)
|
||||
|
||||
assert (
|
||||
not response.choices[0].message.reasoning_content
|
||||
or len(response.choices[0].message.reasoning_content) == 0
|
||||
)
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
0
test/srt/openai_server/function_call/__init__.py
Normal file
0
test/srt/openai_server/function_call/__init__.py
Normal file
@@ -0,0 +1,589 @@
|
||||
import json
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIServerFunctionCalling(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
# Start the local OpenAI Server. If necessary, you can add other parameters such as --enable-tools.
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=[
|
||||
# If your server needs extra parameters to test function calling, please add them here.
|
||||
"--tool-call-parser",
|
||||
"llama3",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(cls.model)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_function_calling_format(self):
|
||||
"""
|
||||
Test: Whether the function call format returned by the AI is correct.
|
||||
When returning a tool call, message.content should be None, and tool_calls should be a list.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add",
|
||||
"description": "Compute the sum of two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "int",
|
||||
"description": "A number",
|
||||
},
|
||||
"b": {
|
||||
"type": "int",
|
||||
"description": "A number",
|
||||
},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [{"role": "user", "content": "Compute (3+5)"}]
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
|
||||
assert (
|
||||
isinstance(tool_calls, list) and len(tool_calls) > 0
|
||||
), "tool_calls should be a non-empty list"
|
||||
|
||||
function_name = tool_calls[0].function.name
|
||||
assert function_name == "add", "Function name should be 'add'"
|
||||
|
||||
def test_function_calling_streaming_simple(self):
|
||||
"""
|
||||
Test: Whether the function name can be correctly recognized in streaming mode.
|
||||
- Expect a function call to be found, and the function name to be correct.
|
||||
- Verify that streaming mode returns at least multiple chunks.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to find the weather for",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "Weather unit (celsius or fahrenheit)",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "unit"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [{"role": "user", "content": "What is the temperature in Paris?"}]
|
||||
|
||||
response_stream = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
top_p=0.8,
|
||||
stream=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
chunks = list(response_stream)
|
||||
self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")
|
||||
|
||||
found_function_name = False
|
||||
for chunk in chunks:
|
||||
choice = chunk.choices[0]
|
||||
# Check whether the current chunk contains tool_calls
|
||||
if choice.delta.tool_calls:
|
||||
tool_call = choice.delta.tool_calls[0]
|
||||
if tool_call.function.name:
|
||||
self.assertEqual(
|
||||
tool_call.function.name,
|
||||
"get_current_weather",
|
||||
"Function name should be 'get_current_weather'",
|
||||
)
|
||||
found_function_name = True
|
||||
break
|
||||
|
||||
self.assertTrue(
|
||||
found_function_name,
|
||||
"Target function name 'get_current_weather' was not found in the streaming chunks",
|
||||
)
|
||||
|
||||
def test_function_calling_streaming_args_parsing(self):
|
||||
"""
|
||||
Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON.
|
||||
- The user request requires multiple parameters.
|
||||
- AI may return the arguments in chunks that need to be concatenated.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add",
|
||||
"description": "Compute the sum of two integers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "integer",
|
||||
"description": "First integer",
|
||||
},
|
||||
"b": {
|
||||
"type": "integer",
|
||||
"description": "Second integer",
|
||||
},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
"strict": True, # Llama-3.2-1B is flaky in tool call. It won't always respond with parameters unless we set strict.
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Please sum 5 and 7, just call the function."}
|
||||
]
|
||||
|
||||
response_stream = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.9,
|
||||
top_p=0.9,
|
||||
stream=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
argument_fragments = []
|
||||
function_name = None
|
||||
for chunk in response_stream:
|
||||
choice = chunk.choices[0]
|
||||
if choice.delta.tool_calls:
|
||||
tool_call = choice.delta.tool_calls[0]
|
||||
# Record the function name on first occurrence
|
||||
function_name = tool_call.function.name or function_name
|
||||
# In case of multiple chunks, JSON fragments may need to be concatenated
|
||||
if tool_call.function.arguments is not None:
|
||||
argument_fragments.append(tool_call.function.arguments)
|
||||
|
||||
self.assertEqual(function_name, "add", "Function name should be 'add'")
|
||||
joined_args = "".join(argument_fragments)
|
||||
self.assertTrue(
|
||||
len(joined_args) > 0,
|
||||
"No parameter fragments were returned in the function call",
|
||||
)
|
||||
|
||||
# Check whether the concatenated JSON is valid
|
||||
try:
|
||||
args_obj = json.loads(joined_args)
|
||||
except json.JSONDecodeError:
|
||||
self.fail(
|
||||
"The concatenated tool call arguments are not valid JSON, parsing failed"
|
||||
)
|
||||
|
||||
self.assertIn("a", args_obj, "Missing parameter 'a'")
|
||||
self.assertIn("b", args_obj, "Missing parameter 'b'")
|
||||
self.assertEqual(str(args_obj["a"]), "5", "Parameter a should be 5")
|
||||
self.assertEqual(str(args_obj["b"]), "7", "Parameter b should be 7")
|
||||
|
||||
def test_function_call_strict(self):
|
||||
"""
|
||||
Test: Whether the strict mode of function calling works as expected.
|
||||
- When strict mode is enabled, the AI should not return a function call if the function name is not recognized.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "sub",
|
||||
"description": "Compute the difference of two integers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_a": {
|
||||
"type": "integer",
|
||||
"description": "First integer",
|
||||
},
|
||||
"int_b": {
|
||||
"type": "integer",
|
||||
"description": "Second integer",
|
||||
},
|
||||
},
|
||||
"required": ["int_a", "int_b"],
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Please compute 5 - 7, using your tool."}
|
||||
]
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
function_name = tool_calls[0].function.name
|
||||
arguments = tool_calls[0].function.arguments
|
||||
args_obj = json.loads(arguments)
|
||||
|
||||
self.assertEqual(function_name, "sub", "Function name should be 'sub'")
|
||||
self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
|
||||
self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
|
||||
|
||||
def test_function_call_required(self):
|
||||
"""
|
||||
Test: Whether tool_choice: "required" works as expected
|
||||
- When tool_choice == "required", the model should return one or more tool_calls.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "sub",
|
||||
"description": "Compute the difference of two integers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_a": {
|
||||
"type": "integer",
|
||||
"description": "First integer",
|
||||
},
|
||||
"int_b": {
|
||||
"type": "integer",
|
||||
"description": "Second integer",
|
||||
},
|
||||
},
|
||||
"required": ["int_a", "int_b"],
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "use this to get latest weather information for a city given its name",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "name of the city to get weather for",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
messages = [{"role": "user", "content": "What is the capital of France?"}]
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
)
|
||||
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
self.assertIsNotNone(tool_calls, "No tool_calls in the response")
|
||||
function_name = tool_calls[0].function.name
|
||||
arguments = tool_calls[0].function.arguments
|
||||
args_obj = json.loads(arguments)
|
||||
|
||||
self.assertEqual(
|
||||
function_name,
|
||||
"get_weather",
|
||||
f"Function name should be 'get_weather', got: {function_name}",
|
||||
)
|
||||
self.assertIn(
|
||||
"city", args_obj, f"Function arguments should have 'city', got: {args_obj}"
|
||||
)
|
||||
|
||||
# Make the test more robust by checking type and accepting valid responses
|
||||
city_value = args_obj["city"]
|
||||
self.assertIsInstance(
|
||||
city_value,
|
||||
str,
|
||||
f"Parameter city should be a string, got: {type(city_value)}",
|
||||
)
|
||||
self.assertTrue(
|
||||
"Paris" in city_value or "France" in city_value,
|
||||
f"Parameter city should contain either 'Paris' or 'France', got: {city_value}",
|
||||
)
|
||||
|
||||
def test_function_call_specific(self):
|
||||
"""
|
||||
Test: Whether tool_choice: ToolChoice works as expected
|
||||
- When tool_choice is a specific ToolChoice, the model should return one or more tool_calls.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "sub",
|
||||
"description": "Compute the difference of two integers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_a": {
|
||||
"type": "integer",
|
||||
"description": "First integer",
|
||||
},
|
||||
"int_b": {
|
||||
"type": "integer",
|
||||
"description": "Second integer",
|
||||
},
|
||||
},
|
||||
"required": ["int_a", "int_b"],
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "use this to get latest weather information for a city given its name",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "name of the city to get weather for",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
messages = [{"role": "user", "content": "What is the capital of France?"}]
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
tools=tools,
|
||||
tool_choice={"type": "function", "function": {"name": "get_weather"}},
|
||||
)
|
||||
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
self.assertIsNotNone(tool_calls, "No tool_calls in the response")
|
||||
function_name = tool_calls[0].function.name
|
||||
arguments = tool_calls[0].function.arguments
|
||||
args_obj = json.loads(arguments)
|
||||
|
||||
self.assertEqual(
|
||||
function_name, "get_weather", "Function name should be 'get_weather'"
|
||||
)
|
||||
self.assertIn("city", args_obj, "Function arguments should have 'city'")
|
||||
|
||||
|
||||
class TestOpenAIPythonicFunctionCalling(CustomTestCase):
|
||||
PYTHONIC_TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather for a given location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The name of the city or location.",
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_tourist_attractions",
|
||||
"description": "Get a list of top tourist attractions for a given city.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to find attractions for.",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
PYTHONIC_MESSAGES = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a travel assistant. "
|
||||
"When asked to call functions, ALWAYS respond ONLY with a python list of function calls, "
|
||||
"using this format: [func_name1(param1=value1, param2=value2), func_name2(param=value)]. "
|
||||
"Do NOT use JSON, do NOT use variables, do NOT use any other format. "
|
||||
"Here is an example:\n"
|
||||
'[get_weather(location="Paris"), get_tourist_attractions(city="Paris")]'
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"I'm planning a trip to Tokyo next week. What's the weather like and what are some top tourist attractions? "
|
||||
"Propose parallel tool calls at once, using the python list of function calls format as shown above."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=[
|
||||
"--tool-call-parser",
|
||||
"pythonic",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(cls.model)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_pythonic_tool_call_prompt(self):
|
||||
"""
|
||||
Test: Explicit prompt for pythonic tool call format without chat template.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=self.PYTHONIC_MESSAGES,
|
||||
tools=self.PYTHONIC_TOOLS,
|
||||
temperature=0.1,
|
||||
stream=False,
|
||||
)
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
self.assertIsInstance(tool_calls, list, "No tool_calls found")
|
||||
self.assertGreaterEqual(len(tool_calls), 1)
|
||||
names = [tc.function.name for tc in tool_calls]
|
||||
self.assertTrue(
|
||||
"get_weather" in names or "get_tourist_attractions" in names,
|
||||
f"Function name '{names}' should container either 'get_weather' or 'get_tourist_attractions'",
|
||||
)
|
||||
|
||||
def test_pythonic_tool_call_streaming(self):
|
||||
"""
|
||||
Test: Streaming pythonic tool call format; assert tool_call index is present.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response_stream = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=self.PYTHONIC_MESSAGES,
|
||||
tools=self.PYTHONIC_TOOLS,
|
||||
temperature=0.1,
|
||||
stream=True,
|
||||
)
|
||||
found_tool_calls = False
|
||||
found_index = False
|
||||
found_names = set()
|
||||
for chunk in response_stream:
|
||||
choice = chunk.choices[0]
|
||||
if getattr(choice.delta, "tool_calls", None):
|
||||
found_tool_calls = True
|
||||
tool_call = choice.delta.tool_calls[0]
|
||||
if hasattr(tool_call, "index") or (
|
||||
isinstance(tool_call, dict) and "index" in tool_call
|
||||
):
|
||||
found_index = True
|
||||
found_names.add(str(tool_call.function.name))
|
||||
|
||||
self.assertTrue(found_tool_calls, "No tool_calls found in streaming response")
|
||||
self.assertTrue(found_index, "No index field found in any streamed tool_call")
|
||||
self.assertTrue(
|
||||
"get_weather" in found_names or "get_tourist_attractions" in found_names,
|
||||
f"Function name '{found_names}' should container either 'get_weather' or 'get_tourist_attractions'",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
524
test/srt/openai_server/function_call/test_tool_choice.py
Normal file
524
test/srt/openai_server/function_call/test_tool_choice.py
Normal file
@@ -0,0 +1,524 @@
|
||||
"""
|
||||
Test script for tool_choice functionality in SGLang
|
||||
Tests: required, auto, and specific function choices in both streaming and non-streaming modes
|
||||
|
||||
# To run the tests, use the following command:
|
||||
#
|
||||
# python3 -m unittest openai_server.function_call.test_tool_choice
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestToolChoiceLlama32(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Mark flaky tests for this model
|
||||
cls.flaky_tests = {
|
||||
"test_multi_tool_scenario_auto",
|
||||
"test_multi_tool_scenario_required",
|
||||
}
|
||||
|
||||
# Use a model that supports function calling
|
||||
cls.model = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
# Start the local OpenAI Server with tool calling support
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=[
|
||||
"--tool-call-parser",
|
||||
"llama3", # Default parser for the test model
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(cls.model)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def setUp(self):
|
||||
self.client = openai.Client(base_url=self.base_url, api_key=self.api_key)
|
||||
self.model_name = self.client.models.list().data[0].id
|
||||
|
||||
def _is_flaky_test(self):
|
||||
"""Check if the current test is marked as flaky for this class"""
|
||||
return (
|
||||
hasattr(self.__class__, "flaky_tests")
|
||||
and self._testMethodName in self.__class__.flaky_tests
|
||||
)
|
||||
|
||||
def get_test_tools(self):
|
||||
"""Get the test tools for function calling"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "use this to get latest weather information for a city given its name",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "name of the city to get weather for",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_pokemon_info",
|
||||
"description": "get detailed information about a pokemon given its name",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "name of the pokemon to get info for",
|
||||
}
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "make_next_step_decision",
|
||||
"description": "You will be given a trace of thinking process in the following format.\n\nQuestion: the input question you must answer\nTOOL: think about what to do, and choose a tool to use ONLY IF there are defined tools. \n You should never call the same tool with the same input twice in a row.\n If the previous conversation history already contains the information that can be retrieved from the tool, you should not call the tool again.\nOBSERVATION: the result of the tool call, NEVER include this in your response, this information will be provided\n... (this TOOL/OBSERVATION can repeat N times)\nANSWER: If you know the answer to the original question, require for more information,\n or you don't know the answer and there are no defined tools or all available tools are not helpful, respond with the answer without mentioning anything else.\n If the previous conversation history already contains the answer, respond with the answer right away.\n\n If no tools are configured, naturally mention this limitation while still being helpful. Briefly note that adding tools in the agent configuration would expand capabilities.\n\nYour task is to respond with the next step to take, based on the traces, \nor answer the question if you have enough information.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"decision": {
|
||||
"type": "string",
|
||||
"description": 'The next step to take, it must be either "TOOL" or "ANSWER". If the previous conversation history already contains the information that can be retrieved from the tool, you should not call the tool again. If there are no defined tools, you should not return "TOOL" in your response.',
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": 'The content of the next step. If the decision is "TOOL", this should be a short and concise reasoning of why you chose the tool, MUST include the tool name. If the decision is "ANSWER", this should be the answer to the question. If no tools are available, integrate this limitation conversationally without sounding scripted.',
|
||||
},
|
||||
},
|
||||
"required": ["decision", "content"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_test_messages(self):
|
||||
"""Get test messages that should trigger tool usage"""
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Answer the following questions as best you can:\n\nYou will be given a trace of thinking process in the following format.\n\nQuestion: the input question you must answer\nTOOL: think about what to do, and choose a tool to use ONLY IF there are defined tools\nOBSERVATION: the result of the tool call or the observation of the current task, NEVER include this in your response, this information will be provided\n... (this TOOL/OBSERVATION can repeat N times)\nANSWER: If you know the answer to the original question, require for more information, \nif the previous conversation history already contains the answer, \nor you don't know the answer and there are no defined tools or all available tools are not helpful, respond with the answer without mentioning anything else.\nYou may use light Markdown formatting to improve clarity (e.g. lists, **bold**, *italics*), but keep it minimal and unobtrusive.\n\nYour task is to respond with the next step to take, based on the traces, \nor answer the question if you have enough information.\n\nQuestion: what is the weather in top 5 populated cities in the US?\n\nTraces:\n\n\nThese are some additional instructions that you should follow:",
|
||||
}
|
||||
]
|
||||
|
||||
def get_travel_tools(self):
|
||||
"""Get tools for travel assistant scenario that should trigger multiple tool calls"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather for a given location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The name of the city or location.",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_tourist_attractions",
|
||||
"description": "Get a list of top tourist attractions for a given city.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to find attractions for.",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_travel_messages(self):
|
||||
"""Get travel assistant messages that should trigger multiple tool calls"""
|
||||
return [
|
||||
{
|
||||
"content": "You are a travel assistant providing real-time weather updates and top tourist attractions.",
|
||||
"role": "system",
|
||||
},
|
||||
{
|
||||
"content": "I'm planning a trip to Tokyo next week. What's the weather like? What are the most amazing sights?",
|
||||
"role": "user",
|
||||
},
|
||||
]
|
||||
|
||||
def test_tool_choice_auto_non_streaming(self):
|
||||
"""Test tool_choice='auto' in non-streaming mode"""
|
||||
tools = self.get_test_tools()
|
||||
messages = self.get_test_messages()
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=400,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(response.choices[0].message)
|
||||
# With auto, tool calls are optional
|
||||
|
||||
def test_tool_choice_auto_streaming(self):
|
||||
"""Test tool_choice='auto' in streaming mode"""
|
||||
tools = self.get_test_tools()
|
||||
messages = self.get_test_messages()
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=400,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Collect streaming response
|
||||
content_chunks = []
|
||||
tool_call_chunks = []
|
||||
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content:
|
||||
content_chunks.append(chunk.choices[0].delta.content)
|
||||
elif chunk.choices[0].delta.tool_calls:
|
||||
tool_call_chunks.extend(chunk.choices[0].delta.tool_calls)
|
||||
|
||||
# Should complete without errors
|
||||
self.assertIsInstance(content_chunks, list)
|
||||
self.assertIsInstance(tool_call_chunks, list)
|
||||
|
||||
def test_tool_choice_required_non_streaming(self):
|
||||
"""Test tool_choice='required' in non-streaming mode"""
|
||||
tools = self.get_test_tools()
|
||||
messages = self.get_test_messages()
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=400,
|
||||
temperature=0.2,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# With required, we should get tool calls
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
self.assertIsNotNone(tool_calls)
|
||||
self.assertGreater(len(tool_calls), 0)
|
||||
|
||||
def test_tool_choice_required_streaming(self):
|
||||
"""Test tool_choice='required' in streaming mode"""
|
||||
tools = self.get_test_tools()
|
||||
messages = self.get_test_messages()
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=400,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Collect streaming response
|
||||
tool_call_chunks = []
|
||||
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
tool_call_chunks.extend(chunk.choices[0].delta.tool_calls)
|
||||
|
||||
# With required, we should get tool call chunks
|
||||
self.assertGreater(len(tool_call_chunks), 0)
|
||||
|
||||
def test_tool_choice_specific_function_non_streaming(self):
|
||||
"""Test tool_choice with specific function in non-streaming mode"""
|
||||
tools = self.get_test_tools()
|
||||
messages = self.get_test_messages()
|
||||
|
||||
tool_choice = {"type": "function", "function": {"name": "get_weather"}}
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=200,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Should call the specific function
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
self.assertIsNotNone(tool_calls)
|
||||
# Our messages ask the top 5 populated cities in the US, so the model could get 5 tool calls
|
||||
self.assertGreaterEqual(len(tool_calls), 1)
|
||||
for tool_call in tool_calls:
|
||||
self.assertEqual(tool_call.function.name, "get_weather")
|
||||
|
||||
def test_tool_choice_specific_function_streaming(self):
|
||||
"""Test tool_choice with specific function in streaming mode"""
|
||||
tools = self.get_test_tools()
|
||||
messages = self.get_test_messages()
|
||||
|
||||
tool_choice = {"type": "function", "function": {"name": "get_weather"}}
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=200,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Collect streaming response
|
||||
tool_call_chunks = []
|
||||
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
tool_call_chunks.extend(chunk.choices[0].delta.tool_calls)
|
||||
|
||||
# Should get tool call chunks for the specific function
|
||||
self.assertGreater(len(tool_call_chunks), 0)
|
||||
|
||||
# Find function name in chunks
|
||||
found_name = None
|
||||
for chunk in tool_call_chunks:
|
||||
if chunk.function and chunk.function.name:
|
||||
found_name = chunk.function.name
|
||||
break
|
||||
|
||||
self.assertEqual(found_name, "get_weather")
|
||||
|
||||
def test_multi_tool_scenario_auto(self):
|
||||
"""Test multi-tool scenario with tool_choice='auto'"""
|
||||
tools = self.get_travel_tools()
|
||||
messages = self.get_travel_messages()
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=400,
|
||||
temperature=0.2,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Should complete without errors
|
||||
self.assertIsNotNone(response.choices[0].message)
|
||||
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
expected_functions = {"get_weather", "get_tourist_attractions"}
|
||||
|
||||
if self._is_flaky_test():
|
||||
# For flaky tests, just verify all called functions are available tools
|
||||
if tool_calls:
|
||||
available_names = [tool["function"]["name"] for tool in tools]
|
||||
for call in tool_calls:
|
||||
self.assertIn(call.function.name, available_names)
|
||||
else:
|
||||
# For non-flaky tests, enforce strict requirements
|
||||
self.assertIsNotNone(tool_calls, "Expected tool calls but got none")
|
||||
self.assertEqual(
|
||||
len(tool_calls), 2, f"Expected 2 tool calls, got {len(tool_calls)}"
|
||||
)
|
||||
|
||||
called_functions = {call.function.name for call in tool_calls}
|
||||
self.assertEqual(
|
||||
called_functions,
|
||||
expected_functions,
|
||||
f"Expected functions {expected_functions}, got {called_functions}",
|
||||
)
|
||||
|
||||
def test_multi_tool_scenario_required(self):
|
||||
"""Test multi-tool scenario with tool_choice='required'"""
|
||||
tools = self.get_travel_tools()
|
||||
messages = self.get_travel_messages()
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=400,
|
||||
temperature=0.2,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# With required, we should get at least one tool call
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
self.assertIsNotNone(tool_calls)
|
||||
self.assertGreater(len(tool_calls), 0)
|
||||
|
||||
# Verify all called functions are available tools
|
||||
available_names = [tool["function"]["name"] for tool in tools]
|
||||
expected_functions = {"get_weather", "get_tourist_attractions"}
|
||||
|
||||
if self._is_flaky_test():
|
||||
# For flaky tests, just ensure basic functionality works
|
||||
self.assertGreater(
|
||||
len(tool_calls),
|
||||
0,
|
||||
f"Expected at least 1 tool call, got {len(tool_calls)}",
|
||||
)
|
||||
for call in tool_calls:
|
||||
self.assertIn(call.function.name, available_names)
|
||||
else:
|
||||
# For non-flaky tests, enforce strict requirements
|
||||
self.assertEqual(
|
||||
len(tool_calls), 2, f"Expected 2 tool calls, got {len(tool_calls)}"
|
||||
)
|
||||
|
||||
called_functions = {call.function.name for call in tool_calls}
|
||||
self.assertEqual(
|
||||
called_functions,
|
||||
expected_functions,
|
||||
f"Expected functions {expected_functions}, got {called_functions}",
|
||||
)
|
||||
|
||||
def test_error_handling_invalid_tool_choice(self):
|
||||
"""Test error handling for invalid tool_choice"""
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
|
||||
tools = self.get_test_tools()
|
||||
messages = self.get_test_messages()
|
||||
|
||||
# Test with invalid function name
|
||||
tool_choice = {"type": "function", "function": {"name": "nonexistent_function"}}
|
||||
|
||||
# The behavior could be either:
|
||||
# 1. Log a warning and continue (if fallback is implemented)
|
||||
# 2. Raise an exception (if strict validation is implemented)
|
||||
|
||||
# First try to capture any logging that might happen
|
||||
with patch("logging.warning") as mock_warning:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
max_tokens=200,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(response.choices[0].message)
|
||||
|
||||
if mock_warning.called:
|
||||
warning_message = mock_warning.call_args[0][0]
|
||||
self.assertIn("nonexistent_function", warning_message)
|
||||
|
||||
|
||||
class TestToolChoiceQwen25(TestToolChoiceLlama32):
|
||||
"""Test tool_choice functionality with Qwen2.5 model"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.flaky_tests = {
|
||||
"test_multi_tool_scenario_auto",
|
||||
"test_multi_tool_scenario_required",
|
||||
}
|
||||
|
||||
cls.model = "Qwen/Qwen2.5-7B-Instruct"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=[
|
||||
"--tool-call-parser",
|
||||
"qwen25",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(cls.model)
|
||||
|
||||
|
||||
class TestToolChoiceMistral(TestToolChoiceLlama32):
|
||||
"""Test tool_choice functionality with Mistral model"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Mark flaky tests for this model
|
||||
cls.flaky_tests = {
|
||||
"test_multi_tool_scenario_auto",
|
||||
"test_multi_tool_scenario_required",
|
||||
}
|
||||
|
||||
cls.model = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=[
|
||||
"--tool-call-parser",
|
||||
"mistral",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(cls.model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
0
test/srt/openai_server/validation/__init__.py
Normal file
0
test/srt/openai_server/validation/__init__.py
Normal file
103
test/srt/openai_server/validation/test_large_max_new_tokens.py
Normal file
103
test/srt/openai_server/validation/test_large_max_new_tokens.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
python3 -m unittest openai_server.validation.test_large_max_new_tokens.TestLargeMaxNewTokens.test_chat_completion
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
STDERR_FILENAME,
|
||||
STDOUT_FILENAME,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestLargeMaxNewTokens(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
cls.stdout = open(STDOUT_FILENAME, "w")
|
||||
cls.stderr = open(STDERR_FILENAME, "w")
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=(
|
||||
"--max-total-token",
|
||||
"1536",
|
||||
"--context-len",
|
||||
"8192",
|
||||
"--decode-log-interval",
|
||||
"2",
|
||||
),
|
||||
env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
|
||||
return_stdout_stderr=(cls.stdout, cls.stderr),
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
cls.stdout.close()
|
||||
cls.stderr.close()
|
||||
os.remove(STDOUT_FILENAME)
|
||||
os.remove(STDERR_FILENAME)
|
||||
|
||||
def run_chat_completion(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please repeat the world 'hello' for 10000 times.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
return response
|
||||
|
||||
def test_chat_completion(self):
|
||||
num_requests = 4
|
||||
|
||||
futures = []
|
||||
with ThreadPoolExecutor(num_requests) as executor:
|
||||
# Send multiple requests
|
||||
for i in range(num_requests):
|
||||
futures.append(executor.submit(self.run_chat_completion))
|
||||
|
||||
# Ensure that they are running concurrently
|
||||
pt = 0
|
||||
while pt >= 0:
|
||||
time.sleep(5)
|
||||
lines = open(STDERR_FILENAME).readlines()
|
||||
for line in lines[pt:]:
|
||||
print(line, end="", flush=True)
|
||||
if f"#running-req: {num_requests}" in line:
|
||||
all_requests_running = True
|
||||
pt = -1
|
||||
break
|
||||
pt += 1
|
||||
|
||||
assert all_requests_running
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
140
test/srt/openai_server/validation/test_matched_stop.py
Normal file
140
test/srt/openai_server/validation/test_matched_stop.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import json
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
MANY_NEW_TOKENS_PROMPT = """
|
||||
Please write an extremely detailed and vivid fantasy story, set in a world full of intricate magic systems, political intrigue, and complex characters.
|
||||
Ensure that you thoroughly describe every scene, character's motivations, and the environment. Include long, engaging dialogues and elaborate on the inner thoughts of the characters.
|
||||
Each section should be as comprehensive as possible to create a rich and immersive experience for the reader.
|
||||
The story should span multiple events, challenges, and character developments over time. Aim to make the story at least 3,000 words long.
|
||||
"""
|
||||
|
||||
|
||||
class TestMatchedStop(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=300,
|
||||
other_args=["--max-running-requests", "10"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_completions_generation(
|
||||
self,
|
||||
prompt=MANY_NEW_TOKENS_PROMPT,
|
||||
max_tokens=1,
|
||||
stop=None,
|
||||
finish_reason=None,
|
||||
matched_stop=None,
|
||||
):
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"model": self.model,
|
||||
"temperature": 0,
|
||||
"top_p": 1,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
if stop is not None:
|
||||
payload["stop"] = stop
|
||||
|
||||
response_completions = requests.post(
|
||||
self.base_url + "/v1/completions",
|
||||
json=payload,
|
||||
)
|
||||
print(json.dumps(response_completions.json()))
|
||||
print("=" * 100)
|
||||
|
||||
assert (
|
||||
response_completions.json()["choices"][0]["finish_reason"] == finish_reason
|
||||
)
|
||||
assert response_completions.json()["choices"][0]["matched_stop"] == matched_stop
|
||||
|
||||
def run_chat_completions_generation(
|
||||
self,
|
||||
prompt=MANY_NEW_TOKENS_PROMPT,
|
||||
max_tokens=1,
|
||||
stop=None,
|
||||
finish_reason=None,
|
||||
matched_stop=None,
|
||||
):
|
||||
chat_payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
"temperature": 0,
|
||||
"top_p": 1,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
if stop is not None:
|
||||
chat_payload["stop"] = stop
|
||||
|
||||
response_chat = requests.post(
|
||||
self.base_url + "/v1/chat/completions",
|
||||
json=chat_payload,
|
||||
)
|
||||
print(json.dumps(response_chat.json()))
|
||||
print("=" * 100)
|
||||
|
||||
assert response_chat.json()["choices"][0]["finish_reason"] == finish_reason
|
||||
assert response_chat.json()["choices"][0]["matched_stop"] == matched_stop
|
||||
|
||||
def test_finish_stop_str(self):
|
||||
self.run_completions_generation(
|
||||
max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n"
|
||||
)
|
||||
self.run_chat_completions_generation(
|
||||
max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n"
|
||||
)
|
||||
|
||||
def test_finish_stop_eos(self):
|
||||
llama_format_prompt = """
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What is 2 + 2?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
"""
|
||||
eos_token_id = 128009
|
||||
self.run_completions_generation(
|
||||
prompt=llama_format_prompt,
|
||||
max_tokens=1000,
|
||||
finish_reason="stop",
|
||||
matched_stop=eos_token_id,
|
||||
)
|
||||
self.run_chat_completions_generation(
|
||||
prompt="What is 2 + 2?",
|
||||
max_tokens=1000,
|
||||
finish_reason="stop",
|
||||
matched_stop=eos_token_id,
|
||||
)
|
||||
|
||||
def test_finish_length(self):
|
||||
self.run_completions_generation(
|
||||
max_tokens=5, finish_reason="length", matched_stop=None
|
||||
)
|
||||
self.run_chat_completions_generation(
|
||||
max_tokens=5, finish_reason="length", matched_stop=None
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,84 @@
|
||||
import openai
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIServerIgnoreEOS(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_ignore_eos(self):
|
||||
"""
|
||||
Test that ignore_eos=True allows generation to continue beyond EOS token
|
||||
and reach the max_tokens limit.
|
||||
"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
max_tokens = 200
|
||||
|
||||
response_default = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Count from 1 to 20."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=max_tokens,
|
||||
extra_body={"ignore_eos": False},
|
||||
)
|
||||
|
||||
response_ignore_eos = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Count from 1 to 20."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=max_tokens,
|
||||
extra_body={"ignore_eos": True},
|
||||
)
|
||||
|
||||
default_tokens = len(
|
||||
self.tokenizer.encode(response_default.choices[0].message.content)
|
||||
)
|
||||
ignore_eos_tokens = len(
|
||||
self.tokenizer.encode(response_ignore_eos.choices[0].message.content)
|
||||
)
|
||||
|
||||
# Check if ignore_eos resulted in more tokens or exactly max_tokens
|
||||
# The ignore_eos response should either:
|
||||
# 1. Have more tokens than the default response (if default stopped at EOS before max_tokens)
|
||||
# 2. Have exactly max_tokens (if it reached the max_tokens limit)
|
||||
self.assertTrue(
|
||||
ignore_eos_tokens > default_tokens or ignore_eos_tokens >= max_tokens,
|
||||
f"ignore_eos did not generate more tokens: {ignore_eos_tokens} vs {default_tokens}",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
response_ignore_eos.choices[0].finish_reason,
|
||||
"length",
|
||||
f"Expected finish_reason='length' for ignore_eos=True, got {response_ignore_eos.choices[0].finish_reason}",
|
||||
)
|
||||
@@ -0,0 +1,88 @@
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestRequestLengthValidation(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
# Start server with auto truncate disabled
|
||||
cls.process = popen_launch_server(
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=("--max-total-tokens", "1000", "--context-length", "1000"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_input_length_longer_than_context_length(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
|
||||
|
||||
long_text = "hello " * 1200 # Will tokenize to more than context length
|
||||
|
||||
with self.assertRaises(openai.BadRequestError) as cm:
|
||||
client.chat.completions.create(
|
||||
model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
messages=[
|
||||
{"role": "user", "content": long_text},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
self.assertIn("is longer than the model's context length", str(cm.exception))
|
||||
|
||||
def test_input_length_longer_than_maximum_allowed_length(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
|
||||
|
||||
long_text = "hello " * 999 # the maximum allowed length is 994 tokens
|
||||
|
||||
with self.assertRaises(openai.BadRequestError) as cm:
|
||||
client.chat.completions.create(
|
||||
model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
messages=[
|
||||
{"role": "user", "content": long_text},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
self.assertIn("is longer than the model's context length", str(cm.exception))
|
||||
|
||||
def test_max_tokens_validation(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
|
||||
|
||||
long_text = "hello "
|
||||
|
||||
with self.assertRaises(openai.BadRequestError) as cm:
|
||||
client.chat.completions.create(
|
||||
model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
messages=[
|
||||
{"role": "user", "content": long_text},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=1200,
|
||||
)
|
||||
|
||||
self.assertIn(
|
||||
"Requested token count exceeds the model's maximum context",
|
||||
str(cm.exception),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user