Files
sglang/test/srt/openai/test_serving_chat.py
Xinyuan Tong 70c471a868 [Refactor] OAI Server components (#7167)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
2025-06-16 20:45:20 -07:00

635 lines
21 KiB
Python

"""
Unit tests for the OpenAIServingChat class from serving_chat.py.
These tests ensure that the refactored implementation maintains compatibility
with the original adapter.py functionality.
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from fastapi import Request
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.managers.io_struct import GenerateReqInput
# Mock TokenizerManager since it may not be directly importable in tests
class MockTokenizerManager:
    """Stand-in tokenizer manager assembled from ``unittest.mock`` objects.

    Exposes the attributes configured below (``model_config``,
    ``server_args``, ``chat_template_name``, ``tokenizer``) together with
    mocked ``generate_request`` and ``create_abort_task`` callables.
    """

    def __init__(self):
        self.model_config = Mock()
        self.model_config.is_multimodal = False
        self.server_args = Mock()
        self.server_args.enable_cache_report = False
        self.server_args.tool_call_parser = "hermes"
        self.server_args.reasoning_parser = None
        self.chat_template_name = "llama-3"

        # Mock tokenizer
        self.tokenizer = Mock()
        self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
        self.tokenizer.decode = Mock(return_value="Test response")
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        # Mock generate_request method
        async def mock_generate():
            """Yield a single canned generation result."""
            yield {
                "text": "Test response",
                "meta_info": {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "prompt_tokens": 10,
                    "completion_tokens": 5,
                    "cached_tokens": 0,
                    "finish_reason": {"type": "stop", "matched": None},
                    "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
                    "output_top_logprobs": None,
                },
                "index": 0,
            }

        # An async generator can be consumed only once, so build a new one on
        # every call instead of handing out a single pre-built instance that
        # would be exhausted after the first consumer.
        self.generate_request = Mock(
            side_effect=lambda *args, **kwargs: mock_generate()
        )
        self.create_abort_task = Mock(return_value=None)
@pytest.fixture
def mock_tokenizer_manager():
    """Provide a newly constructed MockTokenizerManager."""
    manager = MockTokenizerManager()
    return manager
@pytest.fixture
def serving_chat(mock_tokenizer_manager):
    """Provide an OpenAIServingChat wired to the mock tokenizer manager."""
    handler = OpenAIServingChat(mock_tokenizer_manager)
    return handler
@pytest.fixture
def mock_request():
    """Provide a Request-spec'd mock whose headers are an empty dict."""
    # Keyword arguments to Mock are applied as attributes on the new mock.
    return Mock(spec=Request, headers={})
@pytest.fixture
def basic_chat_request():
    """Provide a non-streaming chat completion request with one user turn."""
    user_turn = {"role": "user", "content": "Hello, how are you?"}
    return ChatCompletionRequest(
        model="test-model",
        messages=[user_turn],
        temperature=0.7,
        max_tokens=100,
        stream=False,
    )
@pytest.fixture
def streaming_chat_request():
    """Provide a streaming chat completion request with one user turn."""
    user_turn = {"role": "user", "content": "Hello, how are you?"}
    return ChatCompletionRequest(
        model="test-model",
        messages=[user_turn],
        temperature=0.7,
        max_tokens=100,
        stream=True,
    )
class TestOpenAIServingChatConversion:
    """Checks for converting chat requests into internal requests."""

    def test_convert_to_internal_request_single(
        self, serving_chat, basic_chat_request, mock_tokenizer_manager
    ):
        """One chat request converts to a GenerateReqInput and is returned as-is."""
        # Conversation stub returned by the patched generate_chat_conv.
        fake_conv = Mock()
        fake_conv.get_prompt.return_value = "Test prompt"
        fake_conv.image_data = None
        fake_conv.audio_data = None
        fake_conv.modalities = []
        fake_conv.stop_str = ["</s>"]

        # Canned _process_messages result; the last slot is the
        # tool_call_constraint.
        processed = ("Test prompt", [1, 2, 3], None, None, [], ["</s>"], None)

        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv",
            return_value=fake_conv,
        ), patch.object(
            serving_chat, "_process_messages", return_value=processed
        ):
            internal, echoed = serving_chat._convert_to_internal_request(
                [basic_chat_request], ["test-id"]
            )

        assert isinstance(internal, GenerateReqInput)
        assert internal.stream == basic_chat_request.stream
        assert echoed == basic_chat_request
class TestToolCalls:
    """Tool-call checks carried over from adapter.py."""

    def test_tool_call_request_conversion(self, serving_chat):
        """A request that carries tools converts and keeps its request id."""
        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather information",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            },
        }
        request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[weather_tool],
            tool_choice="auto",
        )
        # Canned _process_messages result; last slot is tool_call_constraint.
        processed = ("Test prompt", [1, 2, 3], None, None, [], ["</s>"], None)

        with patch.object(
            serving_chat, "_process_messages", return_value=processed
        ):
            converted, _ = serving_chat._convert_to_internal_request(
                [request], ["test-id"]
            )

        assert converted.rid == "test-id"
        # Tool call constraint should be processed
        assert request.tools is not None

    def test_tool_choice_none(self, serving_chat):
        """A request with tool_choice="none" still converts with its request id."""
        request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            tools=[{"type": "function", "function": {"name": "test_func"}}],
            tool_choice="none",
        )
        # Canned _process_messages result; last slot is tool_call_constraint.
        processed = ("Test prompt", [1, 2, 3], None, None, [], ["</s>"], None)

        with patch.object(
            serving_chat, "_process_messages", return_value=processed
        ):
            converted, _ = serving_chat._convert_to_internal_request(
                [request], ["test-id"]
            )

        # Tools should not be processed when tool_choice is "none"
        assert converted.rid == "test-id"

    def test_tool_call_response_processing(self, serving_chat):
        """_process_tool_calls reports one call and a "tool_calls" finish reason."""
        raw_output = {
            "text": '{"name": "get_weather", "parameters": {"location": "Paris"}}',
            "meta_info": {
                "output_token_logprobs": [],
                "output_top_logprobs": None,
            },
        }
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                    },
                },
            }
        ]
        finish_reason = {"type": "stop", "matched": None}

        # ``name`` must be assigned after construction; as a constructor
        # keyword it would name the mock itself instead.
        parsed_call = Mock()
        parsed_call.name = "get_weather"
        parsed_call.parameters = '{"location": "Paris"}'

        # Parser stub returned by the patched FunctionCallParser class.
        fake_parser = Mock()
        fake_parser.has_tool_call.return_value = True
        fake_parser.parse_non_stream.return_value = ("", [parsed_call])

        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser",
            return_value=fake_parser,
        ):
            tool_calls, _remaining, new_reason = serving_chat._process_tool_calls(
                raw_output["text"], tools, "hermes", finish_reason
            )

        assert tool_calls is not None
        assert len(tool_calls) == 1
        assert new_reason["type"] == "tool_calls"
class TestMultimodalContent:
    """Multimodal content checks carried over from adapter.py."""

    def test_multimodal_request_with_images(self, serving_chat):
        """_process_messages returns the image data from the patched templates."""
        image_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,..."},
                },
            ],
        }
        request = ChatCompletionRequest(
            model="test-model", messages=[image_message]
        )

        # Set multimodal mode
        serving_chat.tokenizer_manager.model_config.is_multimodal = True

        # Both template paths are stubbed so either branch yields the same data.
        jinja_result = ("prompt", [1, 2, 3], ["image_data"], None, [], [])
        conv_result = ("prompt", ["image_data"], None, [], [])

        with patch.object(
            serving_chat, "_apply_jinja_template", return_value=jinja_result
        ), patch.object(
            serving_chat, "_apply_conversation_template", return_value=conv_result
        ):
            prompt, _, image_data, _, _, _, _ = serving_chat._process_messages(
                request, True
            )

        assert image_data == ["image_data"]
        assert prompt == "prompt"

    def test_multimodal_request_with_audio(self, serving_chat):
        """_process_messages returns the audio data and modalities from the patched templates."""
        audio_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Transcribe this audio"},
                {
                    "type": "audio_url",
                    "audio_url": {"url": "data:audio/wav;base64,UklGR..."},
                },
            ],
        }
        request = ChatCompletionRequest(
            model="test-model", messages=[audio_message]
        )

        serving_chat.tokenizer_manager.model_config.is_multimodal = True

        # Both template paths are stubbed so either branch yields the same data.
        jinja_result = ("prompt", [1, 2, 3], None, ["audio_data"], ["audio"], [])
        conv_result = ("prompt", None, ["audio_data"], ["audio"], [])

        with patch.object(
            serving_chat, "_apply_jinja_template", return_value=jinja_result
        ), patch.object(
            serving_chat, "_apply_conversation_template", return_value=conv_result
        ):
            _, _, _, audio_data, modalities, _, _ = serving_chat._process_messages(
                request, True
            )

        assert audio_data == ["audio_data"]
        assert modalities == ["audio"]
class TestTemplateHandling:
    """Chat template checks carried over from adapter.py."""

    def test_jinja_template_processing(self, serving_chat):
        """With no named chat template, the Jinja result is returned."""
        request = ChatCompletionRequest(
            model="test-model", messages=[{"role": "user", "content": "Hello"}]
        )

        # Mock the template attribute directly
        manager = serving_chat.tokenizer_manager
        manager.chat_template_name = None
        manager.tokenizer.chat_template = "<jinja_template>"

        jinja_result = ("processed_prompt", [1, 2, 3], None, None, [], ["</s>"])

        # hasattr is forced to return True while _process_messages runs.
        # NOTE(review): patching a builtin affects every caller during the
        # block; confirm it is still required by _process_messages.
        with patch.object(
            serving_chat, "_apply_jinja_template", return_value=jinja_result
        ), patch("builtins.hasattr", return_value=True):
            prompt, prompt_ids, _, _, _, _, _ = serving_chat._process_messages(
                request, False
            )

        assert prompt == "processed_prompt"
        assert prompt_ids == [1, 2, 3]

    def test_conversation_template_processing(self, serving_chat):
        """With a named chat template, the conversation-template result is returned."""
        request = ChatCompletionRequest(
            model="test-model", messages=[{"role": "user", "content": "Hello"}]
        )
        serving_chat.tokenizer_manager.chat_template_name = "llama-3"

        conv_result = ("conv_prompt", None, None, [], ["</s>"])

        with patch.object(
            serving_chat, "_apply_conversation_template", return_value=conv_result
        ):
            prompt, _, _, _, _, stop, _ = serving_chat._process_messages(
                request, False
            )

        assert prompt == "conv_prompt"
        assert stop == ["</s>"]

    def test_continue_final_message(self, serving_chat):
        """A request with continue_final_message=True returns the stubbed prompt."""
        request = ChatCompletionRequest(
            model="test-model",
            messages=[
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there"},
            ],
            continue_final_message=True,
        )

        conv_result = ("Hi there", None, None, [], ["</s>"])

        with patch.object(
            serving_chat, "_apply_conversation_template", return_value=conv_result
        ):
            prompt, _, _, _, _, _, _ = serving_chat._process_messages(
                request, False
            )

        # Should handle continue_final_message properly
        assert prompt == "Hi there"
class TestReasoningContent:
    """Test reasoning content separation from adapter.py"""

    def test_reasoning_content_request(self, serving_chat):
        """Test request with reasoning content separation"""
        request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Solve this math problem"}],
            separate_reasoning=True,
            stream_reasoning=False,
        )

        with patch.object(serving_chat, "_process_messages") as mock_process:
            mock_process.return_value = (
                "Test prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
                None,  # tool_call_constraint
            )

            adapted_request, _ = serving_chat._convert_to_internal_request(
                [request], ["test-id"]
            )

            assert adapted_request.rid == "test-id"
            # Identity check instead of ``== True`` (PEP 8): the flag must be
            # the boolean True, not merely a truthy value.
            assert request.separate_reasoning is True

    def test_reasoning_content_response(self, serving_chat):
        """Test reasoning content in response

        NOTE(review): this only drives the patched ReasoningParser; it never
        calls into ``serving_chat``, so it does not verify the serving code's
        own reasoning handling.
        """
        mock_ret_item = {
            "text": "<thinking>This is reasoning</thinking>Answer: 42",
            "meta_info": {
                "output_token_logprobs": [],
                "output_top_logprobs": None,
            },
        }

        # Mock ReasoningParser
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.ReasoningParser"
        ) as mock_parser_class:
            mock_parser = Mock()
            mock_parser.parse_non_stream.return_value = (
                "This is reasoning",
                "Answer: 42",
            )
            mock_parser_class.return_value = mock_parser

            # Simulate reasoning processing
            parser = mock_parser_class(model_type="test", stream_reasoning=False)
            reasoning_text, text = parser.parse_non_stream(mock_ret_item["text"])

            assert reasoning_text == "This is reasoning"
            assert text == "Answer: 42"
class TestSamplingParams:
    """Sampling parameter checks carried over from adapter.py."""

    @staticmethod
    def _sampling_params_for(serving_chat, request):
        """Call _build_sampling_params while _process_messages is stubbed."""
        # Canned _process_messages result; last slot is tool_call_constraint.
        processed = ("Test prompt", [1, 2, 3], None, None, [], ["</s>"], None)
        with patch.object(
            serving_chat, "_process_messages", return_value=processed
        ):
            return serving_chat._build_sampling_params(request, ["</s>"], None)

    def test_all_sampling_parameters(self, serving_chat):
        """Request sampling fields are reflected in the built parameters."""
        request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.8,
            max_tokens=150,
            max_completion_tokens=200,
            min_tokens=5,
            top_p=0.9,
            top_k=50,
            min_p=0.1,
            presence_penalty=0.1,
            frequency_penalty=0.2,
            repetition_penalty=1.1,
            stop=["<|endoftext|>"],
            stop_token_ids=[13, 14],
            regex=r"\d+",
            ebnf="<expr> ::= <number>",
            n=2,
            no_stop_trim=True,
            ignore_eos=True,
            skip_special_tokens=False,
            logit_bias={"1": 0.5, "2": -0.3},
        )

        params = self._sampling_params_for(serving_chat, request)

        # Verify all parameters
        assert params["temperature"] == 0.8
        assert params["max_new_tokens"] == 150
        assert params["min_new_tokens"] == 5
        assert params["top_p"] == 0.9
        assert params["top_k"] == 50
        assert params["min_p"] == 0.1
        assert params["presence_penalty"] == 0.1
        assert params["frequency_penalty"] == 0.2
        assert params["repetition_penalty"] == 1.1
        # The stop list is the one passed to _build_sampling_params.
        assert params["stop"] == ["</s>"]
        assert params["logit_bias"] == {"1": 0.5, "2": -0.3}

    def test_response_format_json_schema(self, serving_chat):
        """A json_schema response format produces a "json_schema" entry."""
        answer_schema = {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
        }
        request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Generate JSON"}],
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "response", "schema": answer_schema},
            },
        )

        params = self._sampling_params_for(serving_chat, request)

        assert "json_schema" in params
        assert '"type": "object"' in params["json_schema"]

    def test_response_format_json_object(self, serving_chat):
        """A json_object response format maps to a generic object schema."""
        request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Generate JSON"}],
            response_format={"type": "json_object"},
        )

        params = self._sampling_params_for(serving_chat, request)

        assert params["json_schema"] == '{"type": "object"}'