sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
0
test/srt/openai_server/basic/__init__.py
Normal file
0
test/srt/openai_server/basic/__init__.py
Normal file
97
test/srt/openai_server/basic/test_openai_embedding.py
Normal file
97
test/srt/openai_server/basic/test_openai_embedding.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIEmbedding(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
# Configure embedding-specific args
|
||||
other_args = ["--is-embedding", "--enable-metrics"]
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=other_args,
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_embedding_single(self):
|
||||
"""Test single embedding request"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(model=self.model, input="Hello world")
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_embedding_batch(self):
|
||||
"""Test batch embedding request"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model, input=["Hello world", "Test text"]
|
||||
)
|
||||
self.assertEqual(len(response.data), 2)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
self.assertTrue(len(response.data[1].embedding) > 0)
|
||||
|
||||
def test_embedding_single_batch_str(self):
|
||||
"""Test embedding with a List[str] and length equals to 1"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(model=self.model, input=["Hello world"])
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_embedding_single_int_list(self):
|
||||
"""Test embedding with a List[int] or List[List[int]]]"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model,
|
||||
input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
|
||||
)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model,
|
||||
input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
|
||||
)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_empty_string_embedding(self):
|
||||
"""Test embedding an empty string."""
|
||||
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
# Text embedding example with empty string
|
||||
text = ""
|
||||
# Expect a BadRequestError for empty input
|
||||
with self.assertRaises(openai.BadRequestError) as cm:
|
||||
client.embeddings.create(
|
||||
model=self.model,
|
||||
input=text,
|
||||
)
|
||||
# check the status code
|
||||
self.assertEqual(cm.exception.status_code, 400)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
669
test/srt/openai_server/basic/test_openai_server.py
Normal file
669
test/srt/openai_server/basic/test_openai_server.py
Normal file
@@ -0,0 +1,669 @@
|
||||
"""
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion_stream
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion_stream
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIServer(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_completion(
|
||||
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
if token_input:
|
||||
prompt_input = self.tokenizer.encode(prompt)
|
||||
num_prompt_tokens = len(prompt_input)
|
||||
else:
|
||||
prompt_input = prompt
|
||||
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
num_prompt_tokens *= 2
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
response = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
echo=echo,
|
||||
logprobs=logprobs,
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
assert len(response.choices) == num_choices * parallel_sample_num
|
||||
|
||||
if echo:
|
||||
text = response.choices[0].text
|
||||
assert text.startswith(prompt)
|
||||
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs
|
||||
assert isinstance(response.choices[0].logprobs.tokens[0], str)
|
||||
assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
|
||||
ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
|
||||
|
||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
|
||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
assert ret_num_top_logprobs > 0
|
||||
|
||||
# when echo=True and request.logprobs>0, logprob_start_len is 0, so the first token's logprob would be None.
|
||||
if not echo:
|
||||
assert response.choices[0].logprobs.token_logprobs[0]
|
||||
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert (
|
||||
response.usage.prompt_tokens == num_prompt_tokens
|
||||
), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
|
||||
assert response.usage.completion_tokens > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def run_completion_stream(
|
||||
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
if token_input:
|
||||
prompt_input = self.tokenizer.encode(prompt)
|
||||
num_prompt_tokens = len(prompt_input)
|
||||
else:
|
||||
prompt_input = prompt
|
||||
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
num_prompt_tokens *= 2
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
generator = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
echo=echo,
|
||||
logprobs=logprobs,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
is_firsts = {}
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
|
||||
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
|
||||
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
is_first = is_firsts.get(index, True)
|
||||
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs, f"no logprobs in response"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.tokens[0], str
|
||||
), f"{response.choices[0].logprobs.tokens[0]} is not a string"
|
||||
if not (is_first and echo):
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.top_logprobs[0], dict
|
||||
), f"top_logprobs was not a dictionary"
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.top_logprobs[0]
|
||||
)
|
||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
|
||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"
|
||||
|
||||
if is_first:
|
||||
if echo:
|
||||
assert response.choices[0].text.startswith(
|
||||
prompt
|
||||
), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
|
||||
is_firsts[index] = False
|
||||
assert response.id, f"no id in response"
|
||||
assert response.created, f"no created in response"
|
||||
|
||||
for index in [i for i in range(parallel_sample_num * num_choices)]:
|
||||
assert not is_firsts.get(
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
def run_chat_completion(self, logprobs, parallel_sample_num):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France? Answer in a few words.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
logprobs=logprobs is not None and logprobs > 0,
|
||||
top_logprobs=logprobs,
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
if logprobs:
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
)
|
||||
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.content[0].top_logprobs
|
||||
)
|
||||
assert (
|
||||
ret_num_top_logprobs == logprobs
|
||||
), f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
|
||||
assert len(response.choices) == parallel_sample_num
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert isinstance(response.choices[0].message.content, str)
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert response.usage.prompt_tokens > 0
|
||||
assert response.usage.completion_tokens > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
generator = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
],
|
||||
temperature=0,
|
||||
logprobs=logprobs is not None and logprobs > 0,
|
||||
top_logprobs=logprobs,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
is_firsts = {}
|
||||
is_finished = {}
|
||||
finish_reason_counts = {}
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
|
||||
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
|
||||
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
if finish_reason is not None:
|
||||
is_finished[index] = True
|
||||
finish_reason_counts[index] = finish_reason_counts.get(index, 0) + 1
|
||||
|
||||
data = response.choices[0].delta
|
||||
|
||||
if is_firsts.get(index, True):
|
||||
assert (
|
||||
data.role == "assistant"
|
||||
), f"data.role was not 'assistant' for first chunk"
|
||||
is_firsts[index] = False
|
||||
continue
|
||||
|
||||
if logprobs and not is_finished.get(index, False):
|
||||
assert response.choices[0].logprobs, f"logprobs was not returned"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
), f"top_logprobs token was not a string"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs, list
|
||||
), f"top_logprobs was not a list"
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.content[0].top_logprobs
|
||||
)
|
||||
assert (
|
||||
ret_num_top_logprobs == logprobs
|
||||
), f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
|
||||
assert (
|
||||
isinstance(data.content, str)
|
||||
or isinstance(data.reasoning_content, str)
|
||||
or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
|
||||
or response.choices[0].finish_reason
|
||||
)
|
||||
assert response.id
|
||||
assert response.created
|
||||
|
||||
for index in [i for i in range(parallel_sample_num)]:
|
||||
assert not is_firsts.get(
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
# Verify that each choice gets exactly one finish_reason chunk
|
||||
for index in range(parallel_sample_num):
|
||||
assert (
|
||||
index in finish_reason_counts
|
||||
), f"No finish_reason found for index {index}"
|
||||
assert (
|
||||
finish_reason_counts[index] == 1
|
||||
), f"Expected 1 finish_reason chunk for index {index}, got {finish_reason_counts[index]}"
|
||||
|
||||
def test_completion(self):
|
||||
for echo in [False, True]:
|
||||
for logprobs in [None, 5]:
|
||||
for use_list_input in [True, False]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
for token_input in [False, True]:
|
||||
self.run_completion(
|
||||
echo,
|
||||
logprobs,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
token_input,
|
||||
)
|
||||
|
||||
def test_completion_stream(self):
|
||||
# parallel sampling and list input are not supported in streaming mode
|
||||
for echo in [False, True]:
|
||||
for logprobs in [None, 5]:
|
||||
for use_list_input in [True, False]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
for token_input in [False, True]:
|
||||
self.run_completion_stream(
|
||||
echo,
|
||||
logprobs,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
token_input,
|
||||
)
|
||||
|
||||
def test_chat_completion(self):
|
||||
for logprobs in [None, 5]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
self.run_chat_completion(logprobs, parallel_sample_num)
|
||||
|
||||
def test_chat_completion_stream(self):
|
||||
for logprobs in [None, 5]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
self.run_chat_completion_stream(logprobs, parallel_sample_num)
|
||||
|
||||
def test_regex(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "name": "[\w]+",\n"""
|
||||
+ r""" "population": [\d]+\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Introduce the capital of France."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
extra_body={"regex": regex},
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
|
||||
try:
|
||||
js_obj = json.loads(text)
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
print("JSONDecodeError", text)
|
||||
raise
|
||||
assert isinstance(js_obj["name"], str)
|
||||
assert isinstance(js_obj["population"], int)
|
||||
|
||||
def test_penalty(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Introduce the capital of France."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
frequency_penalty=1.0,
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
assert isinstance(text, str)
|
||||
|
||||
def test_response_prefill(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": """
|
||||
Extract the name, size, price, and color from this product description as a JSON object:
|
||||
|
||||
<description>
|
||||
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
|
||||
</description>
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "{\n",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
extra_body={"continue_final_message": True},
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0]
|
||||
.message.content.strip()
|
||||
.startswith('"name": "SmartHome Mini",')
|
||||
)
|
||||
|
||||
def test_model_list(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
models = list(client.models.list())
|
||||
assert len(models) == 1
|
||||
assert isinstance(getattr(models[0], "max_model_len", None), int)
|
||||
|
||||
def test_retrieve_model(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
# Test retrieving an existing model
|
||||
retrieved_model = client.models.retrieve(self.model)
|
||||
self.assertEqual(retrieved_model.id, self.model)
|
||||
self.assertEqual(retrieved_model.root, self.model)
|
||||
|
||||
# Test retrieving a non-existent model
|
||||
with self.assertRaises(openai.NotFoundError):
|
||||
client.models.retrieve("non-existent-model")
|
||||
|
||||
|
||||
class TestOpenAIV1Rerank(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.score_tolerance = 1e-2
|
||||
|
||||
# Configure embedding-specific args
|
||||
other_args = [
|
||||
"--is-embedding",
|
||||
"--enable-metrics",
|
||||
"--disable-radix-cache",
|
||||
"--chunked-prefill-size",
|
||||
"-1",
|
||||
"--attention-backend",
|
||||
"torch_native",
|
||||
]
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=other_args,
|
||||
)
|
||||
cls.base_url += "/v1/rerank"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_rerank(self, query, docs):
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"query": query, "documents": docs},
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
def test_rerank_single(self):
|
||||
"""Test single rerank request"""
|
||||
query = TEST_RERANK_QUERY_DOCS[0]["query"]
|
||||
docs = TEST_RERANK_QUERY_DOCS[0]["documents"]
|
||||
|
||||
response = self.run_rerank(query, docs)
|
||||
|
||||
self.assertEqual(len(response), 1)
|
||||
self.assertTrue(isinstance(response[0]["score"], float))
|
||||
self.assertTrue(isinstance(response[0]["document"], str))
|
||||
self.assertTrue(isinstance(response[0]["index"], int))
|
||||
|
||||
def test_rerank_batch(self):
|
||||
"""Test batch rerank request"""
|
||||
query = TEST_RERANK_QUERY_DOCS[1]["query"]
|
||||
docs = TEST_RERANK_QUERY_DOCS[1]["documents"]
|
||||
|
||||
response = self.run_rerank(query, docs)
|
||||
|
||||
self.assertEqual(len(response), 2)
|
||||
self.assertTrue(isinstance(response[0]["score"], float))
|
||||
self.assertTrue(isinstance(response[1]["score"], float))
|
||||
self.assertTrue(isinstance(response[0]["document"], str))
|
||||
self.assertTrue(isinstance(response[1]["document"], str))
|
||||
self.assertTrue(isinstance(response[0]["index"], int))
|
||||
self.assertTrue(isinstance(response[1]["index"], int))
|
||||
|
||||
|
||||
class TestOpenAIV1Score(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
)
|
||||
cls.base_url += "/v1/score"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_score(
|
||||
self, query, items, label_token_ids, apply_softmax=False, item_first=False
|
||||
):
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"items": items,
|
||||
"label_token_ids": label_token_ids,
|
||||
"apply_softmax": apply_softmax,
|
||||
"item_first": item_first,
|
||||
},
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def test_score_text_input(self):
|
||||
"""Test scoring with text input"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Get valid token IDs from the tokenizer
|
||||
label_token_ids = []
|
||||
for item in items:
|
||||
token_ids = self.tokenizer.encode(item, add_special_tokens=False)
|
||||
if not token_ids:
|
||||
self.fail(f"Failed to encode item: {item}")
|
||||
label_token_ids.append(token_ids[0])
|
||||
|
||||
response = self.run_score(query, items, label_token_ids, apply_softmax=True)
|
||||
|
||||
# Handle error responses
|
||||
if response.get("type") == "BadRequestError":
|
||||
self.fail(f"Score request failed with error: {response['message']}")
|
||||
|
||||
# Verify response structure
|
||||
self.assertIn("scores", response, "Response should have a 'scores' field")
|
||||
self.assertIsInstance(response["scores"], list, "scores should be a list")
|
||||
self.assertEqual(
|
||||
len(response["scores"]),
|
||||
len(items),
|
||||
"Number of scores should match number of items",
|
||||
)
|
||||
|
||||
# Each score should be a list of floats in the order of label_token_ids
|
||||
for i, score_list in enumerate(response["scores"]):
|
||||
self.assertIsInstance(score_list, list, f"Score {i} should be a list")
|
||||
self.assertEqual(
|
||||
len(score_list),
|
||||
len(label_token_ids),
|
||||
f"Score {i} length should match label_token_ids",
|
||||
)
|
||||
self.assertTrue(
|
||||
all(isinstance(v, float) for v in score_list),
|
||||
f"Score {i} values should be floats",
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
sum(score_list),
|
||||
1.0,
|
||||
places=6,
|
||||
msg=f"Score {i} probabilities should sum to 1",
|
||||
)
|
||||
|
||||
def test_score_token_input(self):
|
||||
"""Test scoring with token IDs input"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Get valid token IDs
|
||||
query_ids = self.tokenizer.encode(query, add_special_tokens=False)
|
||||
item_ids = [
|
||||
self.tokenizer.encode(item, add_special_tokens=False) for item in items
|
||||
]
|
||||
label_token_ids = [
|
||||
ids[0] for ids in item_ids if ids
|
||||
] # Get first token ID of each item
|
||||
|
||||
response = self.run_score(
|
||||
query_ids, item_ids, label_token_ids, apply_softmax=True
|
||||
)
|
||||
|
||||
# Handle error responses
|
||||
if response.get("type") == "BadRequestError":
|
||||
self.fail(f"Score request failed with error: {response['message']}")
|
||||
|
||||
# Verify response structure
|
||||
self.assertIn("scores", response, "Response should have a 'scores' field")
|
||||
self.assertIsInstance(response["scores"], list, "scores should be a list")
|
||||
self.assertEqual(
|
||||
len(response["scores"]),
|
||||
len(items),
|
||||
"Number of scores should match number of items",
|
||||
)
|
||||
|
||||
# Each score should be a list of floats in the order of label_token_ids
|
||||
for i, score_list in enumerate(response["scores"]):
|
||||
self.assertIsInstance(score_list, list, f"Score {i} should be a list")
|
||||
self.assertEqual(
|
||||
len(score_list),
|
||||
len(label_token_ids),
|
||||
f"Score {i} length should match label_token_ids",
|
||||
)
|
||||
self.assertTrue(
|
||||
all(isinstance(v, float) for v in score_list),
|
||||
f"Score {i} values should be floats",
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
sum(score_list),
|
||||
1.0,
|
||||
places=6,
|
||||
msg=f"Score {i} probabilities should sum to 1",
|
||||
)
|
||||
|
||||
def test_score_error_handling(self):
|
||||
"""Test error handling for invalid inputs"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Test with invalid token ID
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"items": items,
|
||||
"label_token_ids": [999999], # Invalid token ID
|
||||
"apply_softmax": True,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
error_response = response.json()
|
||||
self.assertEqual(error_response["type"], "BadRequestError")
|
||||
self.assertIn("Token ID 999999 is out of vocabulary", error_response["message"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
368
test/srt/openai_server/basic/test_protocol.py
Normal file
368
test/srt/openai_server/basic/test_protocol.py
Normal file
@@ -0,0 +1,368 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for OpenAI API protocol models"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import unittest
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
BatchRequest,
|
||||
BatchResponse,
|
||||
ChatCompletionMessageContentImagePart,
|
||||
ChatCompletionMessageContentTextPart,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionTokenLogprob,
|
||||
ChatMessage,
|
||||
ChoiceLogprobs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
DeltaMessage,
|
||||
EmbeddingObject,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ErrorResponse,
|
||||
FileDeleteResponse,
|
||||
FileRequest,
|
||||
FileResponse,
|
||||
Function,
|
||||
FunctionResponse,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbs,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
MultimodalEmbeddingInput,
|
||||
ResponseFormat,
|
||||
ScoringRequest,
|
||||
ScoringResponse,
|
||||
StreamOptions,
|
||||
StructuralTagResponseFormat,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolChoice,
|
||||
TopLogprob,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
|
||||
class TestModelCard(unittest.TestCase):
|
||||
"""Test ModelCard protocol model"""
|
||||
|
||||
def test_model_card_serialization(self):
|
||||
"""Test model card JSON serialization"""
|
||||
card = ModelCard(id="test-model", max_model_len=4096)
|
||||
data = card.model_dump()
|
||||
self.assertEqual(data["id"], "test-model")
|
||||
self.assertEqual(data["object"], "model")
|
||||
self.assertEqual(data["max_model_len"], 4096)
|
||||
|
||||
|
||||
class TestModelList(unittest.TestCase):
|
||||
"""Test ModelList protocol model"""
|
||||
|
||||
def test_empty_model_list(self):
|
||||
"""Test empty model list creation"""
|
||||
model_list = ModelList()
|
||||
self.assertEqual(model_list.object, "list")
|
||||
self.assertEqual(len(model_list.data), 0)
|
||||
|
||||
def test_model_list_with_cards(self):
|
||||
"""Test model list with model cards"""
|
||||
cards = [
|
||||
ModelCard(id="model-1"),
|
||||
ModelCard(id="model-2", max_model_len=2048),
|
||||
]
|
||||
model_list = ModelList(data=cards)
|
||||
self.assertEqual(len(model_list.data), 2)
|
||||
self.assertEqual(model_list.data[0].id, "model-1")
|
||||
self.assertEqual(model_list.data[1].id, "model-2")
|
||||
|
||||
|
||||
class TestCompletionRequest(unittest.TestCase):
|
||||
"""Test CompletionRequest protocol model"""
|
||||
|
||||
def test_basic_completion_request(self):
|
||||
"""Test basic completion request"""
|
||||
request = CompletionRequest(model="test-model", prompt="Hello world")
|
||||
self.assertEqual(request.model, "test-model")
|
||||
self.assertEqual(request.prompt, "Hello world")
|
||||
self.assertEqual(request.max_tokens, 16) # default
|
||||
self.assertEqual(request.temperature, 1.0) # default
|
||||
self.assertEqual(request.n, 1) # default
|
||||
self.assertFalse(request.stream) # default
|
||||
self.assertFalse(request.echo) # default
|
||||
|
||||
def test_completion_request_sglang_extensions(self):
|
||||
"""Test completion request with SGLang-specific extensions"""
|
||||
request = CompletionRequest(
|
||||
model="test-model",
|
||||
prompt="Hello",
|
||||
top_k=50,
|
||||
min_p=0.1,
|
||||
repetition_penalty=1.1,
|
||||
regex=r"\d+",
|
||||
json_schema='{"type": "object"}',
|
||||
lora_path="/path/to/lora",
|
||||
)
|
||||
self.assertEqual(request.top_k, 50)
|
||||
self.assertEqual(request.min_p, 0.1)
|
||||
self.assertEqual(request.repetition_penalty, 1.1)
|
||||
self.assertEqual(request.regex, r"\d+")
|
||||
self.assertEqual(request.json_schema, '{"type": "object"}')
|
||||
self.assertEqual(request.lora_path, "/path/to/lora")
|
||||
|
||||
def test_completion_request_validation_errors(self):
|
||||
"""Test completion request validation errors"""
|
||||
with self.assertRaises(ValidationError):
|
||||
CompletionRequest() # missing required fields
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
CompletionRequest(model="test-model") # missing prompt
|
||||
|
||||
|
||||
class TestChatCompletionRequest(unittest.TestCase):
|
||||
"""Test ChatCompletionRequest protocol model"""
|
||||
|
||||
def test_basic_chat_completion_request(self):
|
||||
"""Test basic chat completion request"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
request = ChatCompletionRequest(model="test-model", messages=messages)
|
||||
self.assertEqual(request.model, "test-model")
|
||||
self.assertEqual(len(request.messages), 1)
|
||||
self.assertEqual(request.messages[0].role, "user")
|
||||
self.assertEqual(request.messages[0].content, "Hello")
|
||||
self.assertEqual(request.temperature, 0.7) # default
|
||||
self.assertFalse(request.stream) # default
|
||||
self.assertEqual(request.tool_choice, "none") # default when no tools
|
||||
|
||||
def test_chat_completion_tool_choice_validation(self):
|
||||
"""Test tool choice validation logic"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
# No tools, tool_choice should default to "none"
|
||||
request1 = ChatCompletionRequest(model="test-model", messages=messages)
|
||||
self.assertEqual(request1.tool_choice, "none")
|
||||
|
||||
# With tools, tool_choice should default to "auto"
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "test_func", "description": "Test function"},
|
||||
}
|
||||
]
|
||||
request2 = ChatCompletionRequest(
|
||||
model="test-model", messages=messages, tools=tools
|
||||
)
|
||||
self.assertEqual(request2.tool_choice, "auto")
|
||||
|
||||
def test_chat_completion_sglang_extensions(self):
|
||||
"""Test chat completion with SGLang extensions"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=messages,
|
||||
top_k=40,
|
||||
min_p=0.05,
|
||||
separate_reasoning=False,
|
||||
stream_reasoning=False,
|
||||
chat_template_kwargs={"custom_param": "value"},
|
||||
)
|
||||
self.assertEqual(request.top_k, 40)
|
||||
self.assertEqual(request.min_p, 0.05)
|
||||
self.assertFalse(request.separate_reasoning)
|
||||
self.assertFalse(request.stream_reasoning)
|
||||
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
|
||||
|
||||
def test_chat_completion_reasoning_effort(self):
|
||||
"""Test chat completion with reasoning effort"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=messages,
|
||||
reasoning={
|
||||
"enabled": True,
|
||||
"reasoning_effort": "high",
|
||||
},
|
||||
)
|
||||
self.assertEqual(request.reasoning_effort, "high")
|
||||
self.assertEqual(request.chat_template_kwargs, {"thinking": True})
|
||||
|
||||
def test_chat_completion_json_format(self):
|
||||
"""Test chat completion json format"""
|
||||
transcript = "Good morning! It's 7:00 AM, and I'm just waking up. Today is going to be a busy day, "
|
||||
"so let's get started. First, I need to make a quick breakfast. I think I'll have some "
|
||||
"scrambled eggs and toast with a cup of coffee. While I'm cooking, I'll also check my "
|
||||
"emails to see if there's anything urgent."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "The following is a voice message transcript. Only answer in JSON.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": transcript,
|
||||
},
|
||||
]
|
||||
|
||||
class VoiceNote(BaseModel):
|
||||
title: str = Field(description="A title for the voice note")
|
||||
summary: str = Field(
|
||||
description="A short one sentence summary of the voice note."
|
||||
)
|
||||
strict: Optional[bool] = True
|
||||
actionItems: List[str] = Field(
|
||||
description="A list of action items from the voice note"
|
||||
)
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=messages,
|
||||
top_k=40,
|
||||
min_p=0.05,
|
||||
separate_reasoning=False,
|
||||
stream_reasoning=False,
|
||||
chat_template_kwargs={"custom_param": "value"},
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"schema": VoiceNote.model_json_schema(),
|
||||
},
|
||||
)
|
||||
res_format = request.response_format
|
||||
json_format = res_format.json_schema
|
||||
name = json_format.name
|
||||
schema = json_format.schema_
|
||||
strict = json_format.strict
|
||||
self.assertEqual(name, "VoiceNote")
|
||||
self.assertEqual(strict, True)
|
||||
self.assertNotIn("strict", schema["properties"])
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=messages,
|
||||
top_k=40,
|
||||
min_p=0.05,
|
||||
separate_reasoning=False,
|
||||
stream_reasoning=False,
|
||||
chat_template_kwargs={"custom_param": "value"},
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "VoiceNote",
|
||||
"schema": VoiceNote.model_json_schema(),
|
||||
"strict": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
res_format = request.response_format
|
||||
json_format = res_format.json_schema
|
||||
name = json_format.name
|
||||
schema = json_format.schema_
|
||||
strict = json_format.strict
|
||||
self.assertEqual(name, "VoiceNote")
|
||||
self.assertEqual(strict, True)
|
||||
|
||||
|
||||
class TestModelSerialization(unittest.TestCase):
|
||||
"""Test model serialization with hidden states"""
|
||||
|
||||
def test_hidden_states_excluded_when_none(self):
|
||||
"""Test that None hidden_states are excluded with exclude_none=True"""
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
hidden_states=None,
|
||||
)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id="test-id",
|
||||
model="test-model",
|
||||
choices=[choice],
|
||||
usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
|
||||
)
|
||||
|
||||
# Test exclude_none serialization (should exclude None hidden_states)
|
||||
data = response.model_dump(exclude_none=True)
|
||||
self.assertNotIn("hidden_states", data["choices"][0])
|
||||
|
||||
def test_hidden_states_included_when_not_none(self):
|
||||
"""Test that non-None hidden_states are included"""
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
hidden_states=[0.1, 0.2, 0.3],
|
||||
)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id="test-id",
|
||||
model="test-model",
|
||||
choices=[choice],
|
||||
usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
|
||||
)
|
||||
|
||||
# Test exclude_none serialization (should include non-None hidden_states)
|
||||
data = response.model_dump(exclude_none=True)
|
||||
self.assertIn("hidden_states", data["choices"][0])
|
||||
self.assertEqual(data["choices"][0]["hidden_states"], [0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
class TestValidationEdgeCases(unittest.TestCase):
|
||||
"""Test edge cases and validation scenarios"""
|
||||
|
||||
def test_invalid_tool_choice_type(self):
|
||||
"""Test invalid tool choice type"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
with self.assertRaises(ValidationError):
|
||||
ChatCompletionRequest(
|
||||
model="test-model", messages=messages, tool_choice=123
|
||||
)
|
||||
|
||||
def test_negative_token_limits(self):
|
||||
"""Test negative token limits"""
|
||||
with self.assertRaises(ValidationError):
|
||||
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
|
||||
|
||||
def test_model_serialization_roundtrip(self):
|
||||
"""Test that models can be serialized and deserialized"""
|
||||
original_request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
temperature=0.7,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
# Serialize to dict
|
||||
data = original_request.model_dump()
|
||||
|
||||
# Deserialize back
|
||||
restored_request = ChatCompletionRequest(**data)
|
||||
|
||||
self.assertEqual(restored_request.model, original_request.model)
|
||||
self.assertEqual(restored_request.temperature, original_request.temperature)
|
||||
self.assertEqual(restored_request.max_tokens, original_request.max_tokens)
|
||||
self.assertEqual(len(restored_request.messages), len(original_request.messages))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
425
test/srt/openai_server/basic/test_serving_chat.py
Normal file
425
test/srt/openai_server/basic/test_serving_chat.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
Unit-tests for OpenAIServingChat — rewritten to use only the std-lib 'unittest'.
|
||||
Run with either:
|
||||
python tests/test_serving_chat_unit.py -v
|
||||
or
|
||||
python -m unittest discover -s tests -p "test_*unit.py" -v
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
MessageProcessingResult,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
|
||||
|
||||
class _MockTokenizerManager:
|
||||
"""Minimal mock that satisfies OpenAIServingChat."""
|
||||
|
||||
def __init__(self):
|
||||
self.model_config = Mock(is_multimodal=False)
|
||||
self.server_args = Mock(
|
||||
enable_cache_report=False,
|
||||
tool_call_parser="hermes",
|
||||
reasoning_parser=None,
|
||||
)
|
||||
self.chat_template_name: Optional[str] = "llama-3"
|
||||
|
||||
# tokenizer stub
|
||||
self.tokenizer = Mock()
|
||||
self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
|
||||
self.tokenizer.decode.return_value = "Test response"
|
||||
self.tokenizer.chat_template = None
|
||||
self.tokenizer.bos_token_id = 1
|
||||
|
||||
# async generator stub for generate_request
|
||||
async def _mock_generate():
|
||||
yield {
|
||||
"text": "Test response",
|
||||
"meta_info": {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"cached_tokens": 0,
|
||||
"finish_reason": {"type": "stop", "matched": None},
|
||||
"output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
|
||||
"output_top_logprobs": None,
|
||||
},
|
||||
"index": 0,
|
||||
}
|
||||
|
||||
self.generate_request = Mock(return_value=_mock_generate())
|
||||
self.create_abort_task = Mock()
|
||||
|
||||
|
||||
class _MockTemplateManager:
|
||||
"""Minimal mock for TemplateManager."""
|
||||
|
||||
def __init__(self):
|
||||
self.chat_template_name: Optional[str] = "llama-3"
|
||||
self.jinja_template_content_format: Optional[str] = None
|
||||
self.completion_template_name: Optional[str] = None
|
||||
|
||||
|
||||
class ServingChatTestCase(unittest.TestCase):
|
||||
# ------------- common fixtures -------------
|
||||
def setUp(self):
|
||||
self.tm = _MockTokenizerManager()
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.chat = OpenAIServingChat(self.tm, self.template_manager)
|
||||
|
||||
# frequently reused requests
|
||||
self.basic_req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=100,
|
||||
stream=False,
|
||||
)
|
||||
self.stream_req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=100,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
self.fastapi_request = Mock(spec=Request)
|
||||
self.fastapi_request.headers = {}
|
||||
|
||||
# ------------- conversion tests -------------
|
||||
def test_convert_to_internal_request_single(self):
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
|
||||
) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
|
||||
conv_ins = Mock()
|
||||
conv_ins.get_prompt.return_value = "Test prompt"
|
||||
conv_ins.image_data = conv_ins.audio_data = None
|
||||
conv_ins.modalities = []
|
||||
conv_ins.stop_str = ["</s>"]
|
||||
conv_mock.return_value = conv_ins
|
||||
|
||||
proc_mock.return_value = MessageProcessingResult(
|
||||
"Test prompt",
|
||||
[1, 2, 3],
|
||||
None,
|
||||
None,
|
||||
[],
|
||||
["</s>"],
|
||||
None,
|
||||
)
|
||||
|
||||
adapted, processed = self.chat._convert_to_internal_request(self.basic_req)
|
||||
self.assertIsInstance(adapted, GenerateReqInput)
|
||||
self.assertFalse(adapted.stream)
|
||||
self.assertEqual(processed, self.basic_req)
|
||||
|
||||
def test_stop_str_isolation_between_requests(self):
|
||||
"""Test that stop strings from one request don't affect subsequent requests.
|
||||
|
||||
This tests the fix for the bug where conv.stop_str was being mutated globally,
|
||||
causing stop strings from one request to persist in subsequent requests.
|
||||
"""
|
||||
# Mock conversation template with initial stop_str
|
||||
initial_stop_str = ["\n"]
|
||||
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
|
||||
) as conv_mock:
|
||||
# Create a mock conversation object that will be returned by generate_chat_conv
|
||||
conv_ins = Mock()
|
||||
conv_ins.get_prompt.return_value = "Test prompt"
|
||||
conv_ins.image_data = None
|
||||
conv_ins.audio_data = None
|
||||
conv_ins.modalities = []
|
||||
conv_ins.stop_str = (
|
||||
initial_stop_str.copy()
|
||||
) # Template's default stop strings
|
||||
conv_mock.return_value = conv_ins
|
||||
|
||||
# First request with additional stop string
|
||||
req1 = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "First request"}],
|
||||
stop=["CUSTOM_STOP"],
|
||||
)
|
||||
|
||||
# Call the actual _apply_conversation_template method (not mocked)
|
||||
result1 = self.chat._apply_conversation_template(req1, is_multimodal=False)
|
||||
|
||||
# Verify first request has both stop strings
|
||||
expected_stop1 = initial_stop_str + ["CUSTOM_STOP"]
|
||||
self.assertEqual(result1.stop, expected_stop1)
|
||||
|
||||
# Verify the original template's stop_str wasn't mutated after first request
|
||||
self.assertEqual(conv_ins.stop_str, initial_stop_str)
|
||||
|
||||
# Second request without additional stop string
|
||||
req2 = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Second request"}],
|
||||
# No custom stop strings
|
||||
)
|
||||
result2 = self.chat._apply_conversation_template(req2, is_multimodal=False)
|
||||
|
||||
# Verify second request only has original stop strings (no CUSTOM_STOP from req1)
|
||||
self.assertEqual(result2.stop, initial_stop_str)
|
||||
self.assertNotIn("CUSTOM_STOP", result2.stop)
|
||||
self.assertEqual(conv_ins.stop_str, initial_stop_str)
|
||||
|
||||
# ------------- sampling-params -------------
|
||||
def test_sampling_param_build(self):
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
temperature=0.8,
|
||||
max_tokens=150,
|
||||
min_tokens=5,
|
||||
top_p=0.9,
|
||||
stop=["</s>"],
|
||||
)
|
||||
with patch.object(
|
||||
self.chat,
|
||||
"_process_messages",
|
||||
return_value=("Prompt", [1], None, None, [], ["</s>"], None),
|
||||
):
|
||||
params = self.chat._build_sampling_params(req, ["</s>"], None)
|
||||
self.assertEqual(params["temperature"], 0.8)
|
||||
self.assertEqual(params["max_new_tokens"], 150)
|
||||
self.assertEqual(params["min_new_tokens"], 5)
|
||||
self.assertEqual(params["stop"], ["</s>"])
|
||||
|
||||
async def test_unstreamed_tool_args_completion(self):
|
||||
"""Test that remaining tool call arguments are sent when generation finishes."""
|
||||
|
||||
# Mock FunctionCallParser with detector that has partial tool call data
|
||||
mock_parser = Mock()
|
||||
mock_detector = Mock()
|
||||
|
||||
# Simulate a tool call that was partially streamed
|
||||
mock_detector.prev_tool_call_arr = [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"arguments": {"location": "San Francisco", "unit": "celsius"},
|
||||
}
|
||||
]
|
||||
mock_detector.streamed_args_for_tool = [
|
||||
'{"location": "San Francisco"' # Partial arguments streamed so far
|
||||
]
|
||||
mock_parser.detector = mock_detector
|
||||
|
||||
content = {
|
||||
"meta_info": {
|
||||
"id": "chatcmpl-test123",
|
||||
}
|
||||
}
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test",
|
||||
messages=[{"role": "user", "content": "What's the weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
)
|
||||
|
||||
# Test the completion method
|
||||
result = self.chat._check_for_unstreamed_tool_args(
|
||||
parser=mock_parser,
|
||||
content=content,
|
||||
request=request,
|
||||
finish_reason_type="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
# Should return a chunk with remaining arguments
|
||||
self.assertIsNotNone(result, "Should return chunk with remaining arguments")
|
||||
self.assertIn('"arguments":', result, "Should contain arguments field")
|
||||
self.assertIn(
|
||||
', "unit": "celsius"}', result, "Should contain remaining arguments"
|
||||
)
|
||||
self.assertIn(
|
||||
'"finish_reason":null',
|
||||
result,
|
||||
"Should not include finish_reason in completion chunk",
|
||||
)
|
||||
|
||||
async def test_unstreamed_tool_args_no_completion_needed(self):
|
||||
"""Test that no completion chunk is sent when all arguments were already streamed."""
|
||||
|
||||
# Mock FunctionCallParser with detector that has complete tool call data
|
||||
mock_parser = Mock()
|
||||
mock_detector = Mock()
|
||||
|
||||
# Simulate a tool call that was completely streamed
|
||||
mock_detector.prev_tool_call_arr = [
|
||||
{"name": "get_weather", "arguments": {"location": "San Francisco"}}
|
||||
]
|
||||
mock_detector.streamed_args_for_tool = [
|
||||
'{"location": "San Francisco"}' # All arguments already streamed
|
||||
]
|
||||
mock_parser.detector = mock_detector
|
||||
|
||||
content = {
|
||||
"meta_info": {
|
||||
"id": "chatcmpl-test123",
|
||||
}
|
||||
}
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test",
|
||||
messages=[{"role": "user", "content": "What's the weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
)
|
||||
|
||||
# Test the completion method
|
||||
result = self.chat._check_for_unstreamed_tool_args(
|
||||
parser=mock_parser,
|
||||
content=content,
|
||||
request=request,
|
||||
finish_reason_type="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
# Should return None since no completion is needed
|
||||
self.assertIsNone(result, "Should return None when no completion is needed")
|
||||
|
||||
async def test_unstreamed_tool_args_no_parser_data(self):
|
||||
"""Test that no completion chunk is sent when parser has no tool call data."""
|
||||
|
||||
# Mock FunctionCallParser with empty detector
|
||||
mock_parser = Mock()
|
||||
mock_detector = Mock()
|
||||
mock_detector.prev_tool_call_arr = []
|
||||
mock_detector.streamed_args_for_tool = []
|
||||
mock_parser.detector = mock_detector
|
||||
|
||||
content = {
|
||||
"meta_info": {
|
||||
"id": "chatcmpl-test123",
|
||||
}
|
||||
}
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test",
|
||||
messages=[{"role": "user", "content": "What's the weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
)
|
||||
|
||||
# Test the completion method
|
||||
result = self.chat._check_for_unstreamed_tool_args(
|
||||
parser=mock_parser,
|
||||
content=content,
|
||||
request=request,
|
||||
finish_reason_type="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
# Should return None since there's no parser data
|
||||
self.assertIsNone(
|
||||
result, "Should return None when parser has no tool call data"
|
||||
)
|
||||
|
||||
# ------------- kimi_k2 tool_call_id formatting -------------
|
||||
def test_kimi_k2_non_streaming_tool_call_id_format(self):
|
||||
"""Ensure non-streaming tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
|
||||
|
||||
# Force kimi_k2 parser
|
||||
self.chat.tool_call_parser = "kimi_k2"
|
||||
|
||||
# Mock FunctionCallParser.parse_non_stream to return one tool call
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||
) as ParserMock:
|
||||
parser_instance = ParserMock.return_value
|
||||
|
||||
# Build a mock ToolCallItem-like object
|
||||
call_info = Mock()
|
||||
call_info.name = "get_weather"
|
||||
call_info.parameters = '{"city":"Paris"}'
|
||||
call_info.tool_index = 0
|
||||
|
||||
parser_instance.has_tool_call.return_value = True
|
||||
parser_instance.parse_non_stream.return_value = ("", [call_info])
|
||||
|
||||
finish_reason = {"type": "stop", "matched": None}
|
||||
tools = [
|
||||
{"type": "function", "function": {"name": "get_weather"}},
|
||||
]
|
||||
|
||||
tool_calls, remaining_text, _ = self.chat._process_tool_calls(
|
||||
text="<|tool_calls_section_begin|>...",
|
||||
tools=tools,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(tool_calls)
|
||||
self.assertEqual(len(tool_calls), 1)
|
||||
self.assertEqual(tool_calls[0].id, "functions.get_weather:0")
|
||||
self.assertEqual(tool_calls[0].function.name, "get_weather")
|
||||
|
||||
def test_kimi_k2_streaming_tool_call_id_format(self):
|
||||
"""Ensure streaming first chunk tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
|
||||
|
||||
# Force kimi_k2 parser
|
||||
self.chat.tool_call_parser = "kimi_k2"
|
||||
|
||||
# Prepare request with tools
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Patch FunctionCallParser used inside _process_tool_call_stream
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||
) as ParserMock:
|
||||
parser_instance = ParserMock.return_value
|
||||
|
||||
# First call returns one ToolCallItem-like chunk (with name)
|
||||
first_chunk_call = Mock()
|
||||
first_chunk_call.tool_index = 0
|
||||
first_chunk_call.name = "get_weather"
|
||||
first_chunk_call.parameters = ""
|
||||
parser_instance.parse_stream_chunk.side_effect = [
|
||||
("", [first_chunk_call]),
|
||||
("", []),
|
||||
]
|
||||
|
||||
async def collect_first_tool_chunk():
|
||||
gen = self.chat._process_tool_call_stream(
|
||||
index=0,
|
||||
delta="irrelevant",
|
||||
parser_dict={},
|
||||
content={"meta_info": {"id": "chatcmpl-test"}},
|
||||
request=req,
|
||||
has_tool_calls={},
|
||||
)
|
||||
# Get first yielded SSE line
|
||||
line = None
|
||||
async for emitted in gen:
|
||||
line = emitted
|
||||
break
|
||||
return line
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
line = loop.run_until_complete(collect_first_tool_chunk())
|
||||
self.assertIsNotNone(line)
|
||||
self.assertTrue(line.startswith("data: "))
|
||||
|
||||
payload = json.loads(line[len("data: ") :])
|
||||
tool_calls = payload["choices"][0]["delta"]["tool_calls"]
|
||||
self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
157
test/srt/openai_server/basic/test_serving_completions.py
Normal file
157
test/srt/openai_server/basic/test_serving_completions.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Unit-tests for the refactored completions-serving handler (no pytest).
|
||||
Run with:
|
||||
python -m unittest tests.test_serving_completions_unit -v
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
|
||||
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
|
||||
|
||||
class _MockTemplateManager:
|
||||
"""Minimal mock for TemplateManager."""
|
||||
|
||||
def __init__(self):
|
||||
self.chat_template_name: Optional[str] = None
|
||||
self.jinja_template_content_format: Optional[str] = None
|
||||
self.completion_template_name: Optional[str] = (
|
||||
None # Set to None to avoid template processing
|
||||
)
|
||||
|
||||
|
||||
class ServingCompletionTestCase(unittest.TestCase):
|
||||
"""Bundle all prompt/echo tests in one TestCase."""
|
||||
|
||||
# ---------- shared test fixtures ----------
|
||||
def setUp(self):
|
||||
# build the mock TokenizerManager once for every test
|
||||
tm = Mock(spec=TokenizerManager)
|
||||
|
||||
tm.tokenizer = Mock()
|
||||
tm.tokenizer.encode.return_value = [1, 2, 3, 4]
|
||||
tm.tokenizer.decode.return_value = "decoded text"
|
||||
tm.tokenizer.bos_token_id = 1
|
||||
|
||||
tm.model_config = Mock(is_multimodal=False)
|
||||
tm.server_args = Mock(enable_cache_report=False)
|
||||
|
||||
tm.generate_request = AsyncMock()
|
||||
tm.create_abort_task = Mock()
|
||||
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.sc = OpenAIServingCompletion(tm, self.template_manager)
|
||||
|
||||
# ---------- prompt-handling ----------
|
||||
def test_single_string_prompt(self):
|
||||
req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.text, "Hello world")
|
||||
|
||||
def test_single_token_ids_prompt(self):
|
||||
req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
|
||||
|
||||
# ---------- echo-handling ----------
|
||||
def test_echo_with_string_prompt_streaming(self):
|
||||
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "Hello")
|
||||
|
||||
def test_echo_with_list_of_strings_streaming(self):
|
||||
req = CompletionRequest(
|
||||
model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1
|
||||
)
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "A")
|
||||
self.assertEqual(self.sc._get_echo_text(req, 1), "B")
|
||||
|
||||
def test_echo_with_token_ids_streaming(self):
|
||||
req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True)
|
||||
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt"
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt")
|
||||
|
||||
def test_echo_with_multiple_token_ids_streaming(self):
|
||||
req = CompletionRequest(
|
||||
model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1
|
||||
)
|
||||
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded")
|
||||
|
||||
def test_prepare_echo_prompts_non_streaming(self):
|
||||
# single string
|
||||
req = CompletionRequest(model="x", prompt="Hi", echo=True)
|
||||
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])
|
||||
|
||||
# list of strings
|
||||
req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
|
||||
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])
|
||||
|
||||
# token IDs
|
||||
req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
|
||||
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
||||
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
|
||||
|
||||
# ---------- response_format handling ----------
|
||||
def test_response_format_json_object(self):
|
||||
"""Test that response_format json_object is correctly processed in sampling params."""
|
||||
req = CompletionRequest(
|
||||
model="x",
|
||||
prompt="Generate a JSON object:",
|
||||
max_tokens=100,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
self.assertEqual(sampling_params["json_schema"], '{"type": "object"}')
|
||||
|
||||
def test_response_format_json_schema(self):
|
||||
"""Test that response_format json_schema is correctly processed in sampling params."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
|
||||
}
|
||||
req = CompletionRequest(
|
||||
model="x",
|
||||
prompt="Generate a JSON object:",
|
||||
max_tokens=100,
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {"name": "person", "schema": schema},
|
||||
},
|
||||
)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
# The schema should be converted to string by convert_json_schema_to_str
|
||||
self.assertIn("json_schema", sampling_params)
|
||||
self.assertIsInstance(sampling_params["json_schema"], str)
|
||||
|
||||
def test_response_format_structural_tag(self):
|
||||
"""Test that response_format structural_tag is correctly processed in sampling params."""
|
||||
req = CompletionRequest(
|
||||
model="x",
|
||||
prompt="Generate structured output:",
|
||||
max_tokens=100,
|
||||
response_format={
|
||||
"type": "structural_tag",
|
||||
"structures": [{"begin": "<data>", "end": "</data>"}],
|
||||
"triggers": ["<data>"],
|
||||
},
|
||||
)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
# The structural_tag should be processed
|
||||
self.assertIn("structural_tag", sampling_params)
|
||||
self.assertIsInstance(sampling_params["structural_tag"], str)
|
||||
|
||||
def test_response_format_none(self):
|
||||
"""Test that no response_format doesn't add extra constraints."""
|
||||
req = CompletionRequest(model="x", prompt="Generate text:", max_tokens=100)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
# Should not have json_schema or structural_tag from response_format
|
||||
# (but might have json_schema from the legacy json_schema field)
|
||||
self.assertIsNone(sampling_params.get("structural_tag"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
145
test/srt/openai_server/basic/test_serving_embedding.py
Normal file
145
test/srt/openai_server/basic/test_serving_embedding.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import uuid
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
MultimodalEmbeddingInput,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
||||
|
||||
|
||||
# Mock TokenizerManager for embedding tests
|
||||
class _MockTokenizerManager:
|
||||
def __init__(self):
|
||||
self.model_config = Mock()
|
||||
self.model_config.is_multimodal = False
|
||||
self.server_args = Mock()
|
||||
self.server_args.enable_cache_report = False
|
||||
self.model_path = "test-model"
|
||||
|
||||
# Mock tokenizer
|
||||
self.tokenizer = Mock()
|
||||
self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
|
||||
self.tokenizer.decode = Mock(return_value="Test embedding input")
|
||||
self.tokenizer.chat_template = None
|
||||
self.tokenizer.bos_token_id = 1
|
||||
|
||||
# Mock generate_request method for embeddings
|
||||
async def mock_generate_embedding():
|
||||
yield {
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20, # 100-dim embedding
|
||||
"meta_info": {
|
||||
"id": f"embd-{uuid.uuid4()}",
|
||||
"prompt_tokens": 5,
|
||||
},
|
||||
}
|
||||
|
||||
self.generate_request = Mock(return_value=mock_generate_embedding())
|
||||
|
||||
|
||||
# Mock TemplateManager for embedding tests
|
||||
class _MockTemplateManager:
|
||||
def __init__(self):
|
||||
self.chat_template_name = None # None for embeddings usually
|
||||
self.jinja_template_content_format = None
|
||||
self.completion_template_name = None
|
||||
|
||||
|
||||
class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.tokenizer_manager = _MockTokenizerManager()
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.serving_embedding = OpenAIServingEmbedding(
|
||||
self.tokenizer_manager, self.template_manager
|
||||
)
|
||||
|
||||
self.request = Mock(spec=Request)
|
||||
self.request.headers = {}
|
||||
|
||||
self.basic_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input="Hello, how are you?",
|
||||
encoding_format="float",
|
||||
)
|
||||
self.list_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=["Hello, how are you?", "I am fine, thank you!"],
|
||||
encoding_format="float",
|
||||
)
|
||||
self.multimodal_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=[
|
||||
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
|
||||
MultimodalEmbeddingInput(text="World", image=None),
|
||||
],
|
||||
encoding_format="float",
|
||||
)
|
||||
self.token_ids_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=[1, 2, 3, 4, 5],
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
def test_convert_single_string_request(self):
|
||||
"""Test converting single string request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.basic_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(adapted_request.text, "Hello, how are you?")
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.basic_req)
|
||||
|
||||
def test_convert_list_string_request(self):
|
||||
"""Test converting list of strings request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.list_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(
|
||||
adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
|
||||
)
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.list_req)
|
||||
|
||||
def test_convert_token_ids_request(self):
|
||||
"""Test converting token IDs request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.token_ids_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.token_ids_req)
|
||||
|
||||
def test_convert_multimodal_request(self):
|
||||
"""Test converting multimodal request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.multimodal_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
# Should extract text and images separately
|
||||
self.assertEqual(len(adapted_request.text), 2)
|
||||
self.assertIn("Hello", adapted_request.text)
|
||||
self.assertIn("World", adapted_request.text)
|
||||
self.assertEqual(adapted_request.image_data[0], "base64_image_data")
|
||||
self.assertIsNone(adapted_request.image_data[1])
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user