601 lines
22 KiB
Python
601 lines
22 KiB
Python
"""
Unit-tests for OpenAIServingChat — rewritten to use only the std-lib 'unittest'.
Run with either:
    python tests/test_serving_chat_unit.py -v
or
    python -m unittest discover -s tests -p "test_*unit.py" -v
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import unittest
|
|
import uuid
|
|
from typing import Optional
|
|
from unittest.mock import Mock, patch
|
|
|
|
from fastapi import Request
|
|
|
|
from sglang.srt.entrypoints.openai.protocol import (
|
|
ChatCompletionRequest,
|
|
MessageProcessingResult,
|
|
)
|
|
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
|
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
|
|
|
|
|
class _MockTokenizerManager:
    """Minimal mock that satisfies OpenAIServingChat.

    Exposes the attributes the tests in this file rely on: ``model_config``,
    ``server_args``, ``chat_template_name``, a stubbed ``tokenizer``, a
    ``generate_request`` mock and a ``create_abort_task`` mock.
    """

    def __init__(self):
        self.model_config = Mock(is_multimodal=False)
        self.server_args = Mock(
            enable_cache_report=False,
            tool_call_parser="hermes",
            reasoning_parser=None,
        )
        self.chat_template_name: Optional[str] = "llama-3"

        # tokenizer stub
        self.tokenizer = Mock()
        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        self.tokenizer.decode.return_value = "Test response"
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        # async generator stub for generate_request: yields one finished chunk
        async def _mock_generate():
            yield {
                "text": "Test response",
                "meta_info": {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "prompt_tokens": 10,
                    "completion_tokens": 5,
                    "cached_tokens": 0,
                    "finish_reason": {"type": "stop", "matched": None},
                    "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
                    "output_top_logprobs": None,
                },
                "index": 0,
            }

        # Use side_effect (not return_value) so that every call gets a brand-new
        # generator. A single pre-built generator would be exhausted after the
        # first consumer, leaving later calls with an empty stream.
        self.generate_request = Mock(
            side_effect=lambda *args, **kwargs: _mock_generate()
        )
        self.create_abort_task = Mock()
|
|
|
|
|
|
class _MockTemplateManager:
    """Bare-bones TemplateManager stand-in used by the tests in this file."""

    def __init__(self):
        # Only the chat template is named; the other two settings stay unset.
        self.completion_template_name: Optional[str] = None
        self.jinja_template_content_format: Optional[str] = None
        self.chat_template_name: Optional[str] = "llama-3"
|
|
|
|
|
|
class ServingChatTestCase(unittest.TestCase):
    """Tests for OpenAIServingChat built on the mock tokenizer/template managers.

    All tests are synchronous ``unittest`` tests. The two streaming tests drive
    the async tool-call stream through ``asyncio.run`` via
    ``_first_tool_stream_line``.
    """

    # ------------- common fixtures -------------
    def setUp(self):
        """Build a fresh OpenAIServingChat plus the requests shared by tests."""
        self.tm = _MockTokenizerManager()
        self.template_manager = _MockTemplateManager()
        self.chat = OpenAIServingChat(self.tm, self.template_manager)

        # frequently reused requests
        self.basic_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=False,
        )
        self.stream_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=True,
        )

        self.fastapi_request = Mock(spec=Request)
        self.fastapi_request.headers = {}

    # ------------- shared helpers -------------
    @staticmethod
    def _weather_history_messages(final_user_content):
        """Return a chat history holding one completed ``get_weather`` tool call.

        Args:
            final_user_content: text of the trailing user message.
        """
        return [
            {"role": "user", "content": "What's the weather today in paris?"},
            {
                "role": "assistant",
                "content": "Let me do some search first.",
                "tool_calls": [
                    {
                        "id": "functions.get_weather:0",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "Paris"}',
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "content": "It's rainy in paris now.",
                "tool_call_id": "functions.get_weather:0",
            },
            {
                "role": "assistant",
                "content": "It's rainy now.",
            },
            {
                "role": "user",
                "content": final_user_content,
            },
        ]

    def _first_tool_stream_line(self, request):
        """Return the first item yielded by ``_process_tool_call_stream``.

        Returns ``None`` when the stream yields nothing. Must be called while
        any required patches are active, since the stream is created and
        consumed inside this call.
        """

        async def _collect():
            stream = self.chat._process_tool_call_stream(
                index=0,
                delta="irrelevant",
                parser_dict={},
                content={"meta_info": {"id": "chatcmpl-test"}},
                request=request,
                has_tool_calls={},
            )
            # Only the first yielded SSE line is of interest.
            async for emitted in stream:
                return emitted
            return None

        # asyncio.run replaces get_event_loop().run_until_complete(), which is
        # deprecated when called from synchronous code with no running loop.
        return asyncio.run(_collect())

    def _run_unstreamed_tool_args_check(
        self, prev_tool_call_arr, streamed_args_for_tool
    ):
        """Call ``_check_for_unstreamed_tool_args`` with a mocked parser state.

        Args:
            prev_tool_call_arr: value for ``parser.detector.prev_tool_call_arr``.
            streamed_args_for_tool: value for
                ``parser.detector.streamed_args_for_tool``.

        Returns:
            Whatever ``_check_for_unstreamed_tool_args`` returns.
        """
        mock_detector = Mock()
        mock_detector.prev_tool_call_arr = prev_tool_call_arr
        mock_detector.streamed_args_for_tool = streamed_args_for_tool
        mock_parser = Mock()
        mock_parser.detector = mock_detector

        content = {
            "meta_info": {
                "id": "chatcmpl-test123",
            }
        }

        request = ChatCompletionRequest(
            model="test",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
        )

        return self.chat._check_for_unstreamed_tool_args(
            parser=mock_parser,
            content=content,
            request=request,
            finish_reason_type="stop",
            index=0,
        )

    # ------------- conversion tests -------------
    def test_convert_to_internal_request_single(self):
        """A non-streaming request converts to a non-streaming GenerateReqInput."""
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
            conv_ins = Mock()
            conv_ins.get_prompt.return_value = "Test prompt"
            conv_ins.image_data = conv_ins.audio_data = None
            conv_ins.modalities = []
            conv_ins.stop_str = ["</s>"]
            conv_mock.return_value = conv_ins

            proc_mock.return_value = MessageProcessingResult(
                "Test prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
                None,
            )

            adapted, processed = self.chat._convert_to_internal_request(self.basic_req)
            self.assertIsInstance(adapted, GenerateReqInput)
            self.assertFalse(adapted.stream)
            self.assertEqual(processed, self.basic_req)

    def test_stop_str_isolation_between_requests(self):
        """Test that stop strings from one request don't affect subsequent requests.

        This tests the fix for the bug where conv.stop_str was being mutated globally,
        causing stop strings from one request to persist in subsequent requests.
        """
        # Mock conversation template with initial stop_str
        initial_stop_str = ["\n"]

        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        ) as conv_mock:
            # Create a mock conversation object that will be returned by generate_chat_conv
            conv_ins = Mock()
            conv_ins.get_prompt.return_value = "Test prompt"
            conv_ins.image_data = None
            conv_ins.audio_data = None
            conv_ins.modalities = []
            conv_ins.stop_str = (
                initial_stop_str.copy()
            )  # Template's default stop strings
            conv_mock.return_value = conv_ins

            # First request with additional stop string
            req1 = ChatCompletionRequest(
                model="x",
                messages=[{"role": "user", "content": "First request"}],
                stop=["CUSTOM_STOP"],
            )

            # Call the actual _apply_conversation_template method (not mocked)
            result1 = self.chat._apply_conversation_template(req1, is_multimodal=False)

            # Verify first request has both stop strings
            expected_stop1 = initial_stop_str + ["CUSTOM_STOP"]
            self.assertEqual(result1.stop, expected_stop1)

            # Verify the original template's stop_str wasn't mutated after first request
            self.assertEqual(conv_ins.stop_str, initial_stop_str)

            # Second request without additional stop string
            req2 = ChatCompletionRequest(
                model="x",
                messages=[{"role": "user", "content": "Second request"}],
                # No custom stop strings
            )
            result2 = self.chat._apply_conversation_template(req2, is_multimodal=False)

            # Verify second request only has original stop strings (no CUSTOM_STOP from req1)
            self.assertEqual(result2.stop, initial_stop_str)
            self.assertNotIn("CUSTOM_STOP", result2.stop)
            self.assertEqual(conv_ins.stop_str, initial_stop_str)

    # ------------- sampling-params -------------
    def test_sampling_param_build(self):
        """Request sampling fields are carried into the built sampling params."""
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi"}],
            temperature=0.8,
            max_tokens=150,
            min_tokens=5,
            top_p=0.9,
            stop=["</s>"],
        )
        with patch.object(
            self.chat,
            "_process_messages",
            return_value=("Prompt", [1], None, None, [], ["</s>"], None),
        ):
            params = self.chat._build_sampling_params(req, ["</s>"], None)
            self.assertEqual(params["temperature"], 0.8)
            self.assertEqual(params["max_new_tokens"], 150)
            self.assertEqual(params["min_new_tokens"], 5)
            self.assertEqual(params["stop"], ["</s>"])

    # ------------- unstreamed tool-call arguments -------------
    # NOTE: these three tests contain no awaits. Declared as ``async def`` on a
    # plain unittest.TestCase, their coroutines would be created but never run,
    # so no assertion would execute. They are therefore ordinary methods.
    def test_unstreamed_tool_args_completion(self):
        """Test that remaining tool call arguments are sent when generation finishes."""
        # Simulate a tool call that was partially streamed
        result = self._run_unstreamed_tool_args_check(
            prev_tool_call_arr=[
                {
                    "name": "get_weather",
                    "arguments": {"location": "San Francisco", "unit": "celsius"},
                }
            ],
            streamed_args_for_tool=[
                '{"location": "San Francisco"'  # Partial arguments streamed so far
            ],
        )

        # Should return a chunk with remaining arguments
        self.assertIsNotNone(result, "Should return chunk with remaining arguments")
        self.assertIn('"arguments":', result, "Should contain arguments field")
        self.assertIn(
            ', "unit": "celsius"}', result, "Should contain remaining arguments"
        )
        self.assertIn(
            '"finish_reason":null',
            result,
            "Should not include finish_reason in completion chunk",
        )

    def test_unstreamed_tool_args_no_completion_needed(self):
        """Test that no completion chunk is sent when all arguments were already streamed."""
        # Simulate a tool call that was completely streamed
        result = self._run_unstreamed_tool_args_check(
            prev_tool_call_arr=[
                {"name": "get_weather", "arguments": {"location": "San Francisco"}}
            ],
            streamed_args_for_tool=[
                '{"location": "San Francisco"}'  # All arguments already streamed
            ],
        )

        # Should return None since no completion is needed
        self.assertIsNone(result, "Should return None when no completion is needed")

    def test_unstreamed_tool_args_no_parser_data(self):
        """Test that no completion chunk is sent when parser has no tool call data."""
        # Detector with no recorded tool calls at all
        result = self._run_unstreamed_tool_args_check(
            prev_tool_call_arr=[],
            streamed_args_for_tool=[],
        )

        # Should return None since there's no parser data
        self.assertIsNone(
            result, "Should return None when parser has no tool call data"
        )

    # ------------- kimi_k2 tool_call_id formatting -------------
    def test_kimi_k2_non_streaming_tool_call_id_format(self):
        """Ensure non-streaming tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
        # Force kimi_k2 parser
        self.chat.tool_call_parser = "kimi_k2"

        # Mock FunctionCallParser.parse_non_stream to return one tool call
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value

            # Build a mock ToolCallItem-like object
            call_info = Mock()
            call_info.name = "get_weather"
            call_info.parameters = '{"city":"Paris"}'
            call_info.tool_index = 0

            parser_instance.has_tool_call.return_value = True
            parser_instance.parse_non_stream.return_value = ("", [call_info])

            finish_reason = {"type": "stop", "matched": None}
            tools = [
                {"type": "function", "function": {"name": "get_weather"}},
            ]

            # Only the tool calls are inspected; the other results are unused.
            tool_calls, _, _ = self.chat._process_tool_calls(
                text="<|tool_calls_section_begin|>...",
                tools=tools,
                finish_reason=finish_reason,
            )

            self.assertIsNotNone(tool_calls)
            self.assertEqual(len(tool_calls), 1)
            self.assertEqual(tool_calls[0].id, "functions.get_weather:0")
            self.assertEqual(tool_calls[0].function.name, "get_weather")

    def test_kimi_k2_streaming_tool_call_id_format(self):
        """Ensure streaming first chunk tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
        # Force kimi_k2 parser
        self.chat.tool_call_parser = "kimi_k2"

        # Prepare request with tools
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
            stream=True,
        )

        # Patch FunctionCallParser used inside _process_tool_call_stream
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value

            # First call returns one ToolCallItem-like chunk (with name)
            first_chunk_call = Mock()
            first_chunk_call.tool_index = 0
            first_chunk_call.name = "get_weather"
            first_chunk_call.parameters = ""
            parser_instance.parse_stream_chunk.side_effect = [
                ("", [first_chunk_call]),
                ("", []),
            ]

            line = self._first_tool_stream_line(req)
            self.assertIsNotNone(line)
            self.assertTrue(line.startswith("data: "))

            payload = json.loads(line[len("data: ") :])
            tool_calls = payload["choices"][0]["delta"]["tool_calls"]
            self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")

    def test_kimi_k2_non_streaming_tool_call_id_with_history(self):
        """Ensure non-streaming tool_call.id increase with tool calls history for kimi_k2 parser."""
        # Force kimi_k2 parser
        self.chat.tool_call_parser = "kimi_k2"

        # Prepare request with tool calls history
        req = ChatCompletionRequest(
            model="x",
            messages=self._weather_history_messages("What about LA and Tokyo?"),
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
            stream=False,
        )

        # Mock FunctionCallParser.parse_non_stream to return two tool calls
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value

            # Build mock ToolCallItem-like objects
            call_info = Mock()
            call_info.name = "get_weather"
            call_info.parameters = '{"city":"Loa Angeles"}'
            # Kimi-K2 series models might generate a fixed tool_index,
            # ignoring the tool calls history and messing up all following tool calls
            call_info.tool_index = 0

            call_info2 = Mock()
            call_info2.name = "get_weather"
            call_info2.parameters = '{"city":"Tokyo"}'
            call_info2.tool_index = 1

            parser_instance.has_tool_call.return_value = True
            parser_instance.parse_non_stream.return_value = (
                "",
                [call_info, call_info2],
            )

            finish_reason = {"type": "stop", "matched": None}
            tools = [
                {"type": "function", "function": {"name": "get_weather"}},
            ]

            history_tool_calls_cnt = self.chat._get_history_tool_calls_cnt(req)
            tool_calls, _, _ = self.chat._process_tool_calls(
                text="<|tool_calls_section_begin|>...",
                tools=tools,
                finish_reason=finish_reason,
                history_tool_calls_cnt=history_tool_calls_cnt,
            )

            self.assertEqual(history_tool_calls_cnt, 1)
            self.assertIsNotNone(tool_calls)
            self.assertEqual(len(tool_calls), 2)
            self.assertEqual(tool_calls[0].id, "functions.get_weather:1")
            self.assertEqual(tool_calls[0].function.name, "get_weather")
            self.assertEqual(tool_calls[1].id, "functions.get_weather:2")
            self.assertEqual(tool_calls[1].function.name, "get_weather")

    def test_kimi_k2_streaming_tool_call_id_with_history(self):
        """Ensure streaming first chunk tool_call.id increase with tool calls history for kimi_k2 parser."""
        # Force kimi_k2 parser
        self.chat.tool_call_parser = "kimi_k2"

        # Prepare request with tool calls history
        req = ChatCompletionRequest(
            model="x",
            messages=self._weather_history_messages("What about LA?"),
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
            stream=True,
        )

        # Patch FunctionCallParser used inside _process_tool_call_stream
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value

            # First call returns one ToolCallItem-like chunk (with name)
            first_chunk_call = Mock()
            # Kimi-K2 series models might generate a fixed tool_index,
            # ignoring the tool calls history and messing up all following tool calls
            first_chunk_call.tool_index = 0
            first_chunk_call.name = "get_weather"
            first_chunk_call.parameters = ""
            parser_instance.parse_stream_chunk.side_effect = [
                ("", [first_chunk_call]),
                ("", []),
            ]

            line = self._first_tool_stream_line(req)
            self.assertIsNotNone(line)
            self.assertTrue(line.startswith("data: "))

            payload = json.loads(line[len("data: ") :])
            tool_calls = payload["choices"][0]["delta"]["tool_calls"]
            self.assertEqual(tool_calls[0]["id"], "functions.get_weather:1")
|
|
|
|
|
|
# Script entry point: run this module's tests with verbose output.
if __name__ == "__main__":
    unittest.main(module="__main__", verbosity=2)
|