# sglang/test/srt/openai_server/basic/test_serving_chat.py
"""
Unit-tests for OpenAIServingChat — rewritten to use only the std-lib 'unittest'.
Run with either:
python tests/test_serving_chat_unit.py -v
or
python -m unittest discover -s tests -p "test_*unit.py" -v
"""

import json
import unittest
import uuid
from typing import Optional
from unittest.mock import Mock, patch

from fastapi import Request

from sglang.srt.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    MessageProcessingResult,
)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.managers.io_struct import GenerateReqInput


class _MockTokenizerManager:
    """Minimal mock that satisfies OpenAIServingChat."""

    def __init__(self):
        self.model_config = Mock(is_multimodal=False)
        self.server_args = Mock(
            enable_cache_report=False,
            tool_call_parser="hermes",
            reasoning_parser=None,
        )
        self.chat_template_name: Optional[str] = "llama-3"

        # tokenizer stub
        self.tokenizer = Mock()
        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        self.tokenizer.decode.return_value = "Test response"
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        # async generator stub for generate_request
        async def _mock_generate():
            yield {
                "text": "Test response",
                "meta_info": {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "prompt_tokens": 10,
                    "completion_tokens": 5,
                    "cached_tokens": 0,
                    "finish_reason": {"type": "stop", "matched": None},
                    "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
                    "output_top_logprobs": None,
                },
                "index": 0,
            }

        # side_effect (rather than return_value) so every call gets a fresh
        # generator instead of sharing one that is exhausted after first use
        self.generate_request = Mock(side_effect=lambda *args, **kwargs: _mock_generate())
        self.create_abort_task = Mock()


class _MockTemplateManager:
    """Minimal mock for TemplateManager."""

    def __init__(self):
        self.chat_template_name: Optional[str] = "llama-3"
        self.jinja_template_content_format: Optional[str] = None
        self.completion_template_name: Optional[str] = None


class ServingChatTestCase(unittest.IsolatedAsyncioTestCase):
    # IsolatedAsyncioTestCase (rather than plain TestCase) so the ``async def``
    # test methods below are actually awaited instead of being silently skipped
    # with a "coroutine was never awaited" warning.

    # ------------- common fixtures -------------
    def setUp(self):
        self.tm = _MockTokenizerManager()
        self.template_manager = _MockTemplateManager()
        self.chat = OpenAIServingChat(self.tm, self.template_manager)

        # frequently reused requests
        self.basic_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=False,
        )
        self.stream_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=True,
        )

        self.fastapi_request = Mock(spec=Request)
        self.fastapi_request.headers = {}

    # ------------- conversion tests -------------
    def test_convert_to_internal_request_single(self):
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
            conv_ins = Mock()
            conv_ins.get_prompt.return_value = "Test prompt"
            conv_ins.image_data = conv_ins.audio_data = None
            conv_ins.modalities = []
            conv_ins.stop_str = ["</s>"]
            conv_mock.return_value = conv_ins

            proc_mock.return_value = MessageProcessingResult(
                "Test prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
                None,
            )

            adapted, processed = self.chat._convert_to_internal_request(self.basic_req)

            self.assertIsInstance(adapted, GenerateReqInput)
            self.assertFalse(adapted.stream)
            self.assertEqual(processed, self.basic_req)

    def test_stop_str_isolation_between_requests(self):
        """Stop strings from one request must not leak into subsequent requests.

        This guards the fix for a bug where conv.stop_str was mutated globally,
        so stop strings from one request persisted into later requests.
        """
        # Mock conversation template with initial stop_str
        initial_stop_str = ["\n"]

        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        ) as conv_mock:
            # Mock conversation object returned by generate_chat_conv
            conv_ins = Mock()
            conv_ins.get_prompt.return_value = "Test prompt"
            conv_ins.image_data = None
            conv_ins.audio_data = None
            conv_ins.modalities = []
            conv_ins.stop_str = initial_stop_str.copy()  # template's default stop strings
            conv_mock.return_value = conv_ins

            # First request with an additional stop string
            req1 = ChatCompletionRequest(
                model="x",
                messages=[{"role": "user", "content": "First request"}],
                stop=["CUSTOM_STOP"],
            )

            # Call the actual _apply_conversation_template method (not mocked)
            result1 = self.chat._apply_conversation_template(req1, is_multimodal=False)

            # The first request sees both stop strings ...
            expected_stop1 = initial_stop_str + ["CUSTOM_STOP"]
            self.assertEqual(result1.stop, expected_stop1)
            # ... and the template's own stop_str was not mutated
            self.assertEqual(conv_ins.stop_str, initial_stop_str)

            # Second request without an additional stop string
            req2 = ChatCompletionRequest(
                model="x",
                messages=[{"role": "user", "content": "Second request"}],
                # no custom stop strings
            )
            result2 = self.chat._apply_conversation_template(req2, is_multimodal=False)

            # The second request only carries the original stop strings
            # (no CUSTOM_STOP leaked from req1)
            self.assertEqual(result2.stop, initial_stop_str)
            self.assertNotIn("CUSTOM_STOP", result2.stop)
            self.assertEqual(conv_ins.stop_str, initial_stop_str)

    # ------------- sampling-params -------------
    def test_sampling_param_build(self):
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi"}],
            temperature=0.8,
            max_tokens=150,
            min_tokens=5,
            top_p=0.9,
            stop=["</s>"],
        )
        with patch.object(
            self.chat,
            "_process_messages",
            return_value=("Prompt", [1], None, None, [], ["</s>"], None),
        ):
            params = self.chat._build_sampling_params(req, ["</s>"], None)

        self.assertEqual(params["temperature"], 0.8)
        self.assertEqual(params["max_new_tokens"], 150)
        self.assertEqual(params["min_new_tokens"], 5)
        self.assertEqual(params["stop"], ["</s>"])

    # ------------- unstreamed tool-call arguments -------------
    async def test_unstreamed_tool_args_completion(self):
        """Remaining tool-call arguments are sent when generation finishes."""
        # Mock FunctionCallParser whose detector holds partial tool-call data
        mock_parser = Mock()
        mock_detector = Mock()

        # Simulate a tool call that was only partially streamed
        mock_detector.prev_tool_call_arr = [
            {
                "name": "get_weather",
                "arguments": {"location": "San Francisco", "unit": "celsius"},
            }
        ]
        mock_detector.streamed_args_for_tool = [
            '{"location": "San Francisco"'  # partial arguments streamed so far
        ]
        mock_parser.detector = mock_detector

        content = {
            "meta_info": {
                "id": "chatcmpl-test123",
            }
        }
        request = ChatCompletionRequest(
            model="test",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
        )

        result = self.chat._check_for_unstreamed_tool_args(
            parser=mock_parser,
            content=content,
            request=request,
            finish_reason_type="stop",
            index=0,
        )

        # A chunk with the remaining arguments should be returned
        self.assertIsNotNone(result, "Should return chunk with remaining arguments")
        self.assertIn('"arguments":', result, "Should contain arguments field")
        self.assertIn(
            ', "unit": "celsius"}', result, "Should contain remaining arguments"
        )
        self.assertIn(
            '"finish_reason":null',
            result,
            "Should not include finish_reason in completion chunk",
        )
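
        # For context, a sketch of the assumed flushing logic (not the
        # production code): serialize the full arguments and strip the
        # already-streamed prefix, so only the tail is emitted:
        #     full = json.dumps(tool["arguments"])     # '{"location": ..., "unit": "celsius"}'
        #     remaining = full[len(streamed_so_far):]  # ', "unit": "celsius"}'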

    async def test_unstreamed_tool_args_no_completion_needed(self):
        """No completion chunk is sent when all arguments were already streamed."""
        # Mock FunctionCallParser whose detector holds complete tool-call data
        mock_parser = Mock()
        mock_detector = Mock()

        # Simulate a tool call that was completely streamed
        mock_detector.prev_tool_call_arr = [
            {"name": "get_weather", "arguments": {"location": "San Francisco"}}
        ]
        mock_detector.streamed_args_for_tool = [
            '{"location": "San Francisco"}'  # all arguments already streamed
        ]
        mock_parser.detector = mock_detector

        content = {
            "meta_info": {
                "id": "chatcmpl-test123",
            }
        }
        request = ChatCompletionRequest(
            model="test",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
        )

        result = self.chat._check_for_unstreamed_tool_args(
            parser=mock_parser,
            content=content,
            request=request,
            finish_reason_type="stop",
            index=0,
        )

        # Everything was already streamed, so nothing to flush
        self.assertIsNone(result, "Should return None when no completion is needed")

    async def test_unstreamed_tool_args_no_parser_data(self):
        """No completion chunk is sent when the parser has no tool-call data."""
        # Mock FunctionCallParser with an empty detector
        mock_parser = Mock()
        mock_detector = Mock()
        mock_detector.prev_tool_call_arr = []
        mock_detector.streamed_args_for_tool = []
        mock_parser.detector = mock_detector

        content = {
            "meta_info": {
                "id": "chatcmpl-test123",
            }
        }
        request = ChatCompletionRequest(
            model="test",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
        )

        result = self.chat._check_for_unstreamed_tool_args(
            parser=mock_parser,
            content=content,
            request=request,
            finish_reason_type="stop",
            index=0,
        )

        # No parser data at all, so nothing to flush
        self.assertIsNone(
            result, "Should return None when parser has no tool call data"
        )

    # ------------- kimi_k2 tool_call_id formatting -------------
    def test_kimi_k2_non_streaming_tool_call_id_format(self):
        """Non-streaming tool_call.id matches functions.{name}:{index} for the kimi_k2 parser."""
        # Force the kimi_k2 parser
        self.chat.tool_call_parser = "kimi_k2"

        # Mock FunctionCallParser.parse_non_stream to return one tool call
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value

            # Build a mock ToolCallItem-like object
            call_info = Mock()
            call_info.name = "get_weather"
            call_info.parameters = '{"city":"Paris"}'
            call_info.tool_index = 0

            parser_instance.has_tool_call.return_value = True
            parser_instance.parse_non_stream.return_value = ("", [call_info])

            finish_reason = {"type": "stop", "matched": None}
            tools = [
                {"type": "function", "function": {"name": "get_weather"}},
            ]

            tool_calls, remaining_text, finish_reason = self.chat._process_tool_calls(
                text="<|tool_calls_section_begin|>...",
                tools=tools,
                finish_reason=finish_reason,
            )

            self.assertIsNotNone(tool_calls)
            self.assertEqual(len(tool_calls), 1)
            self.assertEqual(tool_calls[0].id, "functions.get_weather:0")
            self.assertEqual(tool_calls[0].function.name, "get_weather")
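
    # For reference, the id scheme asserted above and below boils down to
    # (a sketch of the assumed format, not the production helper):
    #     def _kimi_k2_tool_call_id(name: str, index: int) -> str:
    #         return f"functions.{name}:{index}"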

    async def test_kimi_k2_streaming_tool_call_id_format(self):
        """Streaming first-chunk tool_call.id matches functions.{name}:{index} for the kimi_k2 parser."""
        # Force the kimi_k2 parser
        self.chat.tool_call_parser = "kimi_k2"

        # Prepare a request with tools
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
            stream=True,
        )

        # Patch FunctionCallParser used inside _process_tool_call_stream
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value

            # First call returns one ToolCallItem-like chunk (with name)
            first_chunk_call = Mock()
            first_chunk_call.tool_index = 0
            first_chunk_call.name = "get_weather"
            first_chunk_call.parameters = ""
            parser_instance.parse_stream_chunk.side_effect = [
                ("", [first_chunk_call]),
                ("", []),
            ]

            gen = self.chat._process_tool_call_stream(
                index=0,
                delta="irrelevant",
                parser_dict={},
                content={"meta_info": {"id": "chatcmpl-test"}},
                request=req,
                has_tool_calls={},
            )

            # Grab the first yielded SSE line
            line = None
            async for emitted in gen:
                line = emitted
                break

            self.assertIsNotNone(line)
            self.assertTrue(line.startswith("data: "))
            payload = json.loads(line[len("data: "):])
            tool_calls = payload["choices"][0]["delta"]["tool_calls"]
            self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")

    def test_kimi_k2_non_streaming_tool_call_id_with_history(self):
        """Non-streaming tool_call.id keeps increasing across tool-call history for the kimi_k2 parser."""
        # Force the kimi_k2 parser
        self.chat.tool_call_parser = "kimi_k2"

        # Prepare a request that already contains tool-call history
        req = ChatCompletionRequest(
            model="x",
            messages=[
                {"role": "user", "content": "What's the weather today in Paris?"},
                {
                    "role": "assistant",
                    "content": "Let me do some search first.",
                    "tool_calls": [
                        {
                            "id": "functions.get_weather:0",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": '{"city": "Paris"}',
                            },
                        }
                    ],
                },
                {
                    "role": "tool",
                    "content": "It's rainy in Paris now.",
                    "tool_call_id": "functions.get_weather:0",
                },
                {
                    "role": "assistant",
                    "content": "It's rainy now.",
                },
                {
                    "role": "user",
                    "content": "What about LA and Tokyo?",
                },
            ],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
            stream=False,
        )

        # Mock FunctionCallParser.parse_non_stream to return two tool calls
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value

            # Build mock ToolCallItem-like objects. Kimi-K2 series models may
            # emit tool_index starting from 0 again, ignoring the tool-call
            # history, which would corrupt the IDs of all following tool calls.
            call_info = Mock()
            call_info.name = "get_weather"
            call_info.parameters = '{"city":"Los Angeles"}'
            call_info.tool_index = 0

            call_info2 = Mock()
            call_info2.name = "get_weather"
            call_info2.parameters = '{"city":"Tokyo"}'
            call_info2.tool_index = 1

            parser_instance.has_tool_call.return_value = True
            parser_instance.parse_non_stream.return_value = (
                "",
                [call_info, call_info2],
            )

            finish_reason = {"type": "stop", "matched": None}
            tools = [
                {"type": "function", "function": {"name": "get_weather"}},
            ]

            history_tool_calls_cnt = self.chat._get_history_tool_calls_cnt(req)
            tool_calls, remaining_text, _ = self.chat._process_tool_calls(
                text="<|tool_calls_section_begin|>...",
                tools=tools,
                finish_reason=finish_reason,
                history_tool_calls_cnt=history_tool_calls_cnt,
            )

            self.assertEqual(history_tool_calls_cnt, 1)
            self.assertIsNotNone(tool_calls)
            self.assertEqual(len(tool_calls), 2)
            self.assertEqual(tool_calls[0].id, "functions.get_weather:1")
            self.assertEqual(tool_calls[0].function.name, "get_weather")
            self.assertEqual(tool_calls[1].id, "functions.get_weather:2")
            self.assertEqual(tool_calls[1].function.name, "get_weather")
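
    # A minimal sketch of what _get_history_tool_calls_cnt is assumed to
    # compute (not the production code): the number of tool calls already
    # present on assistant messages, so newly parsed calls continue the
    # functions.{name}:{index} sequence instead of restarting at 0:
    #     cnt = sum(
    #         len(msg.tool_calls or [])
    #         for msg in req.messages
    #         if msg.role == "assistant"
    #     )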

    async def test_kimi_k2_streaming_tool_call_id_with_history(self):
        """Streaming first-chunk tool_call.id keeps increasing across tool-call history for the kimi_k2 parser."""
        # Force the kimi_k2 parser
        self.chat.tool_call_parser = "kimi_k2"

        # Prepare a request that already contains tool-call history
        req = ChatCompletionRequest(
            model="x",
            messages=[
                {"role": "user", "content": "What's the weather today in Paris?"},
                {
                    "role": "assistant",
                    "content": "Let me do some search first.",
                    "tool_calls": [
                        {
                            "id": "functions.get_weather:0",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": '{"city": "Paris"}',
                            },
                        }
                    ],
                },
                {
                    "role": "tool",
                    "content": "It's rainy in Paris now.",
                    "tool_call_id": "functions.get_weather:0",
                },
                {
                    "role": "assistant",
                    "content": "It's rainy now.",
                },
                {
                    "role": "user",
                    "content": "What about LA?",
                },
            ],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
            stream=True,
        )

        # Patch FunctionCallParser used inside _process_tool_call_stream
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value

            # First call returns one ToolCallItem-like chunk (with name).
            # Kimi-K2 series models may emit tool_index starting from 0 again,
            # ignoring the tool-call history, which would corrupt the IDs of
            # all following tool calls.
            first_chunk_call = Mock()
            first_chunk_call.tool_index = 0
            first_chunk_call.name = "get_weather"
            first_chunk_call.parameters = ""
            parser_instance.parse_stream_chunk.side_effect = [
                ("", [first_chunk_call]),
                ("", []),
            ]

            gen = self.chat._process_tool_call_stream(
                index=0,
                delta="irrelevant",
                parser_dict={},
                content={"meta_info": {"id": "chatcmpl-test"}},
                request=req,
                has_tool_calls={},
            )

            # Grab the first yielded SSE line
            line = None
            async for emitted in gen:
                line = emitted
                break

            self.assertIsNotNone(line)
            self.assertTrue(line.startswith("data: "))
            payload = json.loads(line[len("data: "):])
            tool_calls = payload["choices"][0]["delta"]["tool_calls"]
            self.assertEqual(tool_calls[0]["id"], "functions.get_weather:1")
if __name__ == "__main__":
unittest.main(verbosity=2)