[Refactor] OAI Server components (#7167)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Xinyuan Tong
2025-06-16 20:45:20 -07:00
committed by GitHub
parent 1a9c2c9214
commit 70c471a868
12 changed files with 4424 additions and 0 deletions

View File

@@ -0,0 +1,683 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for OpenAI API protocol models"""
import json
import time
from typing import Dict, List, Optional
import pytest
from pydantic import ValidationError
from sglang.srt.entrypoints.openai.protocol import (
BatchRequest,
BatchResponse,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentTextPart,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatCompletionTokenLogprob,
ChatMessage,
ChoiceLogprobs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
DeltaMessage,
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
FileDeleteResponse,
FileRequest,
FileResponse,
Function,
FunctionResponse,
JsonSchemaResponseFormat,
LogProbs,
ModelCard,
ModelList,
MultimodalEmbeddingInput,
ResponseFormat,
ScoringRequest,
ScoringResponse,
StreamOptions,
StructuralTagResponseFormat,
Tool,
ToolCall,
ToolChoice,
TopLogprob,
UsageInfo,
)
class TestModelCard:
    """Exercise construction and serialization of ``ModelCard``."""

    def test_basic_model_card_creation(self):
        """A card built from just an ``id`` carries the expected defaults."""
        model_card = ModelCard(id="test-model")

        assert model_card.id == "test-model"
        # Values filled in by the model when not supplied.
        assert model_card.object == "model"
        assert model_card.owned_by == "sglang"
        assert isinstance(model_card.created, int)
        # Optional fields remain unset.
        assert model_card.root is None
        assert model_card.max_model_len is None

    def test_model_card_with_optional_fields(self):
        """Explicitly supplied optional fields are stored unchanged."""
        supplied = {
            "id": "test-model",
            "root": "/path/to/model",
            "max_model_len": 2048,
            "created": 1234567890,
        }
        model_card = ModelCard(**supplied)

        assert model_card.id == supplied["id"]
        assert model_card.root == supplied["root"]
        assert model_card.max_model_len == supplied["max_model_len"]
        assert model_card.created == supplied["created"]

    def test_model_card_serialization(self):
        """``model_dump`` exposes the id, object type and max length."""
        dumped = ModelCard(id="test-model", max_model_len=4096).model_dump()

        assert dumped["id"] == "test-model"
        assert dumped["object"] == "model"
        assert dumped["max_model_len"] == 4096
class TestModelList:
    """Exercise the ``ModelList`` container model."""

    def test_empty_model_list(self):
        """A list built without arguments is typed "list" and holds no cards."""
        empty_listing = ModelList()

        assert empty_listing.object == "list"
        assert len(empty_listing.data) == 0

    def test_model_list_with_cards(self):
        """Cards passed in ``data`` are kept in order."""
        first_card = ModelCard(id="model-1")
        second_card = ModelCard(id="model-2", max_model_len=2048)

        listing = ModelList(data=[first_card, second_card])

        assert len(listing.data) == 2
        assert listing.data[0].id == "model-1"
        assert listing.data[1].id == "model-2"
class TestErrorResponse:
    """Exercise the ``ErrorResponse`` model."""

    def test_basic_error_response(self):
        """Message, type and code are stored; ``param`` is None when omitted."""
        err = ErrorResponse(
            message="Invalid request",
            type="BadRequestError",
            code=400,
        )

        assert err.object == "error"
        assert err.message == "Invalid request"
        assert err.type == "BadRequestError"
        assert err.code == 400
        assert err.param is None

    def test_error_response_with_param(self):
        """The offending parameter name is stored when given."""
        err = ErrorResponse(
            message="Invalid temperature",
            type="ValidationError",
            code=422,
            param="temperature",
        )

        assert err.param == "temperature"
class TestUsageInfo:
    """Exercise the ``UsageInfo`` token-accounting model."""

    def test_basic_usage_info(self):
        """Token counts are stored; prompt token details are None when omitted."""
        token_usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)

        assert token_usage.prompt_tokens == 10
        assert token_usage.completion_tokens == 20
        assert token_usage.total_tokens == 30
        assert token_usage.prompt_tokens_details is None

    def test_usage_info_with_cache_details(self):
        """A ``prompt_tokens_details`` mapping is kept as provided."""
        cache_details = {"cached_tokens": 5}
        token_usage = UsageInfo(
            prompt_tokens=10,
            completion_tokens=20,
            total_tokens=30,
            prompt_tokens_details=cache_details,
        )

        assert token_usage.prompt_tokens_details == {"cached_tokens": 5}
class TestCompletionRequest:
    """Exercise the ``CompletionRequest`` model."""

    def test_basic_completion_request(self):
        """Only model and prompt are given; the remaining fields use defaults."""
        completion_request = CompletionRequest(model="test-model", prompt="Hello world")

        assert completion_request.model == "test-model"
        assert completion_request.prompt == "Hello world"
        # Defaults
        assert completion_request.max_tokens == 16
        assert completion_request.temperature == 1.0
        assert completion_request.n == 1
        assert not completion_request.stream
        assert not completion_request.echo

    def test_completion_request_with_options(self):
        """Explicit sampling and streaming options are stored as given."""
        prompts = ["Hello", "world"]
        stop_sequences = [".", "!"]
        completion_request = CompletionRequest(
            model="test-model",
            prompt=prompts,
            max_tokens=100,
            temperature=0.7,
            top_p=0.9,
            n=2,
            stream=True,
            echo=True,
            stop=stop_sequences,
            logprobs=5,
        )

        assert completion_request.prompt == ["Hello", "world"]
        assert completion_request.max_tokens == 100
        assert completion_request.temperature == 0.7
        assert completion_request.top_p == 0.9
        assert completion_request.n == 2
        assert completion_request.stream
        assert completion_request.echo
        assert completion_request.stop == [".", "!"]
        assert completion_request.logprobs == 5

    def test_completion_request_sglang_extensions(self):
        """SGLang-specific fields are accepted and stored."""
        extension_fields = {
            "top_k": 50,
            "min_p": 0.1,
            "repetition_penalty": 1.1,
            "regex": r"\d+",
            "json_schema": '{"type": "object"}',
            "lora_path": "/path/to/lora",
        }
        completion_request = CompletionRequest(
            model="test-model", prompt="Hello", **extension_fields
        )

        assert completion_request.top_k == 50
        assert completion_request.min_p == 0.1
        assert completion_request.repetition_penalty == 1.1
        assert completion_request.regex == r"\d+"
        assert completion_request.json_schema == '{"type": "object"}'
        assert completion_request.lora_path == "/path/to/lora"

    def test_completion_request_validation_errors(self):
        """Leaving out required fields raises ``ValidationError``."""
        incomplete_inputs = (
            {},  # missing required fields
            {"model": "test-model"},  # missing prompt
        )
        for kwargs in incomplete_inputs:
            with pytest.raises(ValidationError):
                CompletionRequest(**kwargs)
class TestCompletionResponse:
    """Exercise the ``CompletionResponse`` model."""

    def test_basic_completion_response(self):
        """A response keeps its id, model, choices and usage."""
        only_choice = CompletionResponseChoice(
            index=0,
            text="Hello world!",
            finish_reason="stop",
        )
        token_usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)

        completion = CompletionResponse(
            id="test-id",
            model="test-model",
            choices=[only_choice],
            usage=token_usage,
        )

        assert completion.id == "test-id"
        assert completion.object == "text_completion"
        assert completion.model == "test-model"
        assert len(completion.choices) == 1
        assert completion.choices[0].text == "Hello world!"
        assert completion.usage.total_tokens == 5
class TestChatCompletionRequest:
    """Exercise the ``ChatCompletionRequest`` model."""

    def test_basic_chat_completion_request(self):
        """One user message and a model name; other fields use defaults."""
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
        )

        assert chat_request.model == "test-model"
        assert len(chat_request.messages) == 1
        assert chat_request.messages[0].role == "user"
        assert chat_request.messages[0].content == "Hello"
        # Defaults
        assert chat_request.temperature == 0.7
        assert not chat_request.stream
        # With no tools supplied, tool_choice is "none".
        assert chat_request.tool_choice == "none"

    def test_chat_completion_with_multimodal_content(self):
        """A message whose content mixes text and image parts is parsed per part."""
        text_part = {"type": "text", "text": "What's in this image?"}
        image_part = {
            "type": "image_url",
            "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
        }
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": [text_part, image_part]}],
        )

        assert len(chat_request.messages[0].content) == 2
        assert chat_request.messages[0].content[0].type == "text"
        assert chat_request.messages[0].content[1].type == "image_url"

    def test_chat_completion_with_tools(self):
        """Tool definitions are parsed and tool_choice becomes "auto"."""
        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather information",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            },
        }
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[weather_tool],
        )

        assert len(chat_request.tools) == 1
        assert chat_request.tools[0].function.name == "get_weather"
        # With tools supplied, tool_choice is "auto".
        assert chat_request.tool_choice == "auto"

    def test_chat_completion_tool_choice_validation(self):
        """tool_choice defaults to "none" without tools and "auto" with them."""
        user_messages = [{"role": "user", "content": "Hello"}]

        # No tools supplied.
        without_tools = ChatCompletionRequest(model="test-model", messages=user_messages)
        assert without_tools.tool_choice == "none"

        # One tool supplied.
        tool_spec = {
            "type": "function",
            "function": {"name": "test_func", "description": "Test function"},
        }
        with_tools = ChatCompletionRequest(
            model="test-model",
            messages=user_messages,
            tools=[tool_spec],
        )
        assert with_tools.tool_choice == "auto"

    def test_chat_completion_sglang_extensions(self):
        """SGLang-specific chat fields are accepted and stored."""
        template_kwargs = {"custom_param": "value"}
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            top_k=40,
            min_p=0.05,
            separate_reasoning=False,
            stream_reasoning=False,
            chat_template_kwargs=template_kwargs,
        )

        assert chat_request.top_k == 40
        assert chat_request.min_p == 0.05
        assert not chat_request.separate_reasoning
        assert not chat_request.stream_reasoning
        assert chat_request.chat_template_kwargs == {"custom_param": "value"}
class TestChatCompletionResponse:
    """Exercise the ``ChatCompletionResponse`` model."""

    def test_basic_chat_completion_response(self):
        """A plain assistant reply is reachable through the response choices."""
        reply_choice = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content="Hello there!"),
            finish_reason="stop",
        )
        chat_response = ChatCompletionResponse(
            id="test-id",
            model="test-model",
            choices=[reply_choice],
            usage=UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5),
        )

        assert chat_response.id == "test-id"
        assert chat_response.object == "chat.completion"
        assert chat_response.model == "test-model"
        assert len(chat_response.choices) == 1
        assert chat_response.choices[0].message.content == "Hello there!"

    def test_chat_completion_response_with_tool_calls(self):
        """A reply carrying tool calls keeps the call details and finish reason."""
        weather_call = ToolCall(
            id="call_123",
            function=FunctionResponse(
                name="get_weather",
                arguments='{"location": "San Francisco"}',
            ),
        )
        assistant_message = ChatMessage(
            role="assistant", content=None, tool_calls=[weather_call]
        )
        tool_choice_entry = ChatCompletionResponseChoice(
            index=0,
            message=assistant_message,
            finish_reason="tool_calls",
        )
        chat_response = ChatCompletionResponse(
            id="test-id",
            model="test-model",
            choices=[tool_choice_entry],
            usage=UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        )

        first_choice = chat_response.choices[0]
        assert first_choice.message.tool_calls[0].function.name == "get_weather"
        assert first_choice.finish_reason == "tool_calls"
class TestEmbeddingRequest:
    """Exercise the ``EmbeddingRequest`` model."""

    def test_basic_embedding_request(self):
        """A single string input; encoding format and dimensions use defaults."""
        embedding_request = EmbeddingRequest(model="test-model", input="Hello world")

        assert embedding_request.model == "test-model"
        assert embedding_request.input == "Hello world"
        # Defaults
        assert embedding_request.encoding_format == "float"
        assert embedding_request.dimensions is None

    def test_embedding_request_with_list_input(self):
        """A list of strings plus an explicit dimension count is stored as given."""
        texts = ["Hello", "world"]
        embedding_request = EmbeddingRequest(
            model="test-model", input=texts, dimensions=512
        )

        assert embedding_request.input == ["Hello", "world"]
        assert embedding_request.dimensions == 512

    def test_multimodal_embedding_request(self):
        """Multimodal inputs keep their text and image fields, including a None image."""
        with_image = MultimodalEmbeddingInput(text="Hello", image="base64_image_data")
        text_only = MultimodalEmbeddingInput(text="World", image=None)

        embedding_request = EmbeddingRequest(
            model="test-model", input=[with_image, text_only]
        )

        assert len(embedding_request.input) == 2
        assert embedding_request.input[0].text == "Hello"
        assert embedding_request.input[0].image == "base64_image_data"
        assert embedding_request.input[1].text == "World"
        assert embedding_request.input[1].image is None
class TestEmbeddingResponse:
    """Exercise the ``EmbeddingResponse`` model."""

    def test_basic_embedding_response(self):
        """A response with one embedding exposes the vector, index and usage."""
        vector = [0.1, 0.2, 0.3]
        embedding_response = EmbeddingResponse(
            data=[EmbeddingObject(embedding=vector, index=0)],
            model="test-model",
            usage=UsageInfo(prompt_tokens=3, total_tokens=3),
        )

        assert embedding_response.object == "list"
        assert len(embedding_response.data) == 1
        assert embedding_response.data[0].embedding == [0.1, 0.2, 0.3]
        assert embedding_response.data[0].index == 0
        assert embedding_response.usage.prompt_tokens == 3
class TestScoringRequest:
    """Exercise the ``ScoringRequest`` model."""

    def test_basic_scoring_request(self):
        """Text query and items; the boolean options are falsy by default."""
        scoring_request = ScoringRequest(
            model="test-model",
            query="Hello",
            items=["World", "Earth"],
        )

        assert scoring_request.model == "test-model"
        assert scoring_request.query == "Hello"
        assert scoring_request.items == ["World", "Earth"]
        # Defaults
        assert not scoring_request.apply_softmax
        assert not scoring_request.item_first

    def test_scoring_request_with_token_ids(self):
        """Token-id query/items and explicit options are stored as given."""
        query_ids = [1, 2, 3]
        item_ids = [[4, 5], [6, 7]]
        scoring_request = ScoringRequest(
            model="test-model",
            query=query_ids,
            items=item_ids,
            label_token_ids=[8, 9],
            apply_softmax=True,
            item_first=True,
        )

        assert scoring_request.query == [1, 2, 3]
        assert scoring_request.items == [[4, 5], [6, 7]]
        assert scoring_request.label_token_ids == [8, 9]
        assert scoring_request.apply_softmax
        assert scoring_request.item_first
class TestScoringResponse:
    """Exercise the ``ScoringResponse`` model."""

    def test_basic_scoring_response(self):
        """Scores and model are stored; usage is None when omitted."""
        score_matrix = [[0.1, 0.9], [0.3, 0.7]]
        scoring_response = ScoringResponse(scores=score_matrix, model="test-model")

        assert scoring_response.object == "scoring"
        assert scoring_response.scores == [[0.1, 0.9], [0.3, 0.7]]
        assert scoring_response.model == "test-model"
        assert scoring_response.usage is None  # default
class TestFileOperations:
    """Exercise the file request/response models."""

    def test_file_request(self):
        """Raw bytes and purpose are stored on the request."""
        payload = b"test file content"
        upload = FileRequest(file=payload, purpose="batch")

        assert upload.file == payload
        assert upload.purpose == "batch"

    def test_file_response(self):
        """File metadata is stored and the object type is "file"."""
        file_info = FileResponse(
            id="file-123",
            bytes=1024,
            created_at=1234567890,
            filename="test.jsonl",
            purpose="batch",
        )

        assert file_info.id == "file-123"
        assert file_info.object == "file"
        assert file_info.bytes == 1024
        assert file_info.filename == "test.jsonl"

    def test_file_delete_response(self):
        """The delete response reports the file id and the deleted flag."""
        deletion = FileDeleteResponse(id="file-123", deleted=True)

        assert deletion.id == "file-123"
        assert deletion.object == "file"
        assert deletion.deleted
class TestBatchOperations:
    """Exercise the batch request/response models."""

    def test_batch_request(self):
        """All batch request fields are stored as given."""
        custom_metadata = {"custom": "value"}
        batch_request = BatchRequest(
            input_file_id="file-123",
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata=custom_metadata,
        )

        assert batch_request.input_file_id == "file-123"
        assert batch_request.endpoint == "/v1/chat/completions"
        assert batch_request.completion_window == "24h"
        assert batch_request.metadata == {"custom": "value"}

    def test_batch_response(self):
        """A new batch response is typed "batch" and starts in "validating"."""
        batch_response = BatchResponse(
            id="batch-123",
            endpoint="/v1/chat/completions",
            input_file_id="file-123",
            completion_window="24h",
            created_at=1234567890,
        )

        assert batch_response.id == "batch-123"
        assert batch_response.object == "batch"
        assert batch_response.status == "validating"  # default
        assert batch_response.endpoint == "/v1/chat/completions"
class TestResponseFormats:
    """Exercise the response-format models."""

    def test_basic_response_format(self):
        """A "json_object" format has no attached JSON schema."""
        response_format = ResponseFormat(type="json_object")

        assert response_format.type == "json_object"
        assert response_format.json_schema is None

    def test_json_schema_response_format(self):
        """A "json_schema" format exposes the schema name and body."""
        person_schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
        }
        schema_format = JsonSchemaResponseFormat(
            name="person_schema",
            description="Person schema",
            schema=person_schema,
        )

        response_format = ResponseFormat(type="json_schema", json_schema=schema_format)

        assert response_format.type == "json_schema"
        assert response_format.json_schema.name == "person_schema"
        # The schema passed as ``schema=`` is read back through ``schema_``.
        assert response_format.json_schema.schema_ == person_schema

    def test_structural_tag_response_format(self):
        """A structural-tag format keeps its structures and triggers."""
        thinking_structure = {
            "begin": "<thinking>",
            "schema_": {"type": "string"},
            "end": "</thinking>",
        }
        tag_format = StructuralTagResponseFormat(
            type="structural_tag",
            structures=[thinking_structure],
            triggers=["think"],
        )

        assert tag_format.type == "structural_tag"
        assert len(tag_format.structures) == 1
        assert tag_format.triggers == ["think"]
class TestLogProbs:
    """Exercise the log-probability models."""

    def test_basic_logprobs(self):
        """``LogProbs`` stores the parallel token and logprob lists."""
        token_strings = ["Hello", " ", "world"]
        token_scores = [-0.1, -0.2, -0.3]
        completion_logprobs = LogProbs(
            text_offset=[0, 5, 11],
            token_logprobs=token_scores,
            tokens=token_strings,
            top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
        )

        assert len(completion_logprobs.tokens) == 3
        assert completion_logprobs.tokens == ["Hello", " ", "world"]
        assert completion_logprobs.token_logprobs == [-0.1, -0.2, -0.3]

    def test_choice_logprobs(self):
        """``ChoiceLogprobs`` wraps per-token chat logprob entries."""
        hello_bytes = [72, 101, 108, 108, 111]
        top_entry = TopLogprob(token="Hello", bytes=hello_bytes, logprob=-0.1)
        token_entry = ChatCompletionTokenLogprob(
            token="Hello",
            bytes=[72, 101, 108, 108, 111],
            logprob=-0.1,
            top_logprobs=[top_entry],
        )

        wrapped = ChoiceLogprobs(content=[token_entry])

        assert len(wrapped.content) == 1
        assert wrapped.content[0].token == "Hello"
class TestStreamingModels:
    """Exercise the streaming-related models."""

    def test_stream_options(self):
        """``include_usage`` is stored on ``StreamOptions``."""
        stream_options = StreamOptions(include_usage=True)

        assert stream_options.include_usage

    def test_chat_completion_stream_response(self):
        """A stream chunk is typed "chat.completion.chunk" and carries the delta."""
        chunk_choice = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(role="assistant", content="Hello"),
        )
        chunk = ChatCompletionStreamResponse(
            id="test-id",
            model="test-model",
            choices=[chunk_choice],
        )

        assert chunk.object == "chat.completion.chunk"
        assert chunk.choices[0].delta.content == "Hello"
class TestValidationEdgeCases:
    """Edge cases around request validation and serialization."""

    def test_empty_messages_validation(self):
        """An empty ``messages`` list is rejected."""
        with pytest.raises(ValidationError):
            ChatCompletionRequest(model="test-model", messages=[])

    def test_invalid_tool_choice_type(self):
        """An integer ``tool_choice`` is rejected."""
        user_messages = [{"role": "user", "content": "Hello"}]

        with pytest.raises(ValidationError):
            ChatCompletionRequest(
                model="test-model",
                messages=user_messages,
                tool_choice=123,
            )

    def test_negative_token_limits(self):
        """A negative ``max_tokens`` is rejected."""
        with pytest.raises(ValidationError):
            CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)

    def test_invalid_temperature_range(self):
        """A large temperature is accepted as-is.

        The protocol does not currently enforce a temperature range; this
        test records that permissive behavior.
        """
        hot_request = CompletionRequest(
            model="test-model", prompt="Hello", temperature=5.0
        )

        assert hot_request.temperature == 5.0  # Currently allowed

    def test_model_serialization_roundtrip(self):
        """Dumping a request to a dict and rebuilding it preserves key fields."""
        before = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.7,
            max_tokens=100,
        )

        # dict -> model again
        after = ChatCompletionRequest(**before.model_dump())

        assert after.model == before.model
        assert after.temperature == before.temperature
        assert after.max_tokens == before.max_tokens
        assert len(after.messages) == len(before.messages)
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # reports failures to the calling shell / CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))

View File

@@ -0,0 +1,634 @@
"""
Unit tests for the OpenAIServingChat class from serving_chat.py.
These tests ensure that the refactored implementation maintains compatibility
with the original adapter.py functionality.
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from fastapi import Request
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.managers.io_struct import GenerateReqInput
# Mock TokenizerManager since it may not be directly importable in tests
class MockTokenizerManager:
    """Lightweight stand-in for the tokenizer manager used by the serving tests.

    Only the attributes assigned in ``__init__`` are configured explicitly;
    they are plain values or ``Mock`` objects, so no real model or tokenizer
    is loaded.
    """

    def __init__(self):
        # Model / server configuration read by the code under test.
        self.model_config = Mock()
        self.model_config.is_multimodal = False
        self.server_args = Mock()
        self.server_args.enable_cache_report = False
        self.server_args.tool_call_parser = "hermes"
        self.server_args.reasoning_parser = None
        self.chat_template_name = "llama-3"

        # Mock tokenizer with canned encode/decode results.
        self.tokenizer = Mock()
        self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
        self.tokenizer.decode = Mock(return_value="Test response")
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        async def mock_generate():
            """Yield a single canned generation result."""
            yield {
                "text": "Test response",
                "meta_info": {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "prompt_tokens": 10,
                    "completion_tokens": 5,
                    "cached_tokens": 0,
                    "finish_reason": {"type": "stop", "matched": None},
                    "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
                    "output_top_logprobs": None,
                },
                "index": 0,
            }

        # Build a fresh async generator on every call. Passing
        # ``return_value=mock_generate()`` would create one generator object
        # up front, which is exhausted after its first consumer, so any later
        # ``generate_request()`` call would silently yield nothing.
        self.generate_request = Mock(
            side_effect=lambda *args, **kwargs: mock_generate()
        )
        self.create_abort_task = Mock(return_value=None)
@pytest.fixture
def mock_tokenizer_manager():
    """Provide a fresh ``MockTokenizerManager`` for a test."""
    manager = MockTokenizerManager()
    return manager
@pytest.fixture
def serving_chat(mock_tokenizer_manager):
    """Provide an ``OpenAIServingChat`` built on the mock tokenizer manager."""
    chat_handler = OpenAIServingChat(mock_tokenizer_manager)
    return chat_handler
@pytest.fixture
def mock_request():
    """Provide a ``Mock`` specced to FastAPI's ``Request`` with empty headers."""
    fake_request = Mock(spec=Request)
    fake_request.headers = {}
    return fake_request
@pytest.fixture
def basic_chat_request():
    """Provide a non-streaming chat completion request with one user message."""
    user_turn = {"role": "user", "content": "Hello, how are you?"}
    return ChatCompletionRequest(
        model="test-model",
        messages=[user_turn],
        temperature=0.7,
        max_tokens=100,
        stream=False,
    )
@pytest.fixture
def streaming_chat_request():
    """Provide a streaming chat completion request with one user message."""
    user_turn = {"role": "user", "content": "Hello, how are you?"}
    return ChatCompletionRequest(
        model="test-model",
        messages=[user_turn],
        temperature=0.7,
        max_tokens=100,
        stream=True,
    )
class TestOpenAIServingChatConversion:
    """Tests for converting chat requests to the internal request format."""

    def test_convert_to_internal_request_single(
        self, serving_chat, basic_chat_request, mock_tokenizer_manager
    ):
        """A single chat request converts to a ``GenerateReqInput``."""
        conv_target = "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        with patch(conv_target) as mock_conv, patch.object(
            serving_chat, "_process_messages"
        ) as mock_process:
            # Conversation object handed back by the patched generate_chat_conv.
            fake_conversation = Mock()
            fake_conversation.get_prompt.return_value = "Test prompt"
            fake_conversation.image_data = None
            fake_conversation.audio_data = None
            fake_conversation.modalities = []
            fake_conversation.stop_str = ["</s>"]
            mock_conv.return_value = fake_conversation

            # Canned 7-tuple returned by the patched _process_messages.
            mock_process.return_value = (
                "Test prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
                None,  # tool_call_constraint
            )

            converted = serving_chat._convert_to_internal_request(
                [basic_chat_request], ["test-id"]
            )
            adapted_request, processed_request = converted

            assert isinstance(adapted_request, GenerateReqInput)
            assert adapted_request.stream == basic_chat_request.stream
            assert processed_request == basic_chat_request
class TestToolCalls:
    """Tool-call request conversion and response processing tests."""

    @staticmethod
    def _canned_process_result():
        """Return a fresh 7-tuple used as the patched ``_process_messages`` result."""
        return (
            "Test prompt",
            [1, 2, 3],
            None,
            None,
            [],
            ["</s>"],
            None,  # tool_call_constraint
        )

    def test_tool_call_request_conversion(self, serving_chat):
        """A request carrying tools converts and keeps the supplied request id."""
        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather information",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            },
        }
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[weather_tool],
            tool_choice="auto",
        )

        with patch.object(serving_chat, "_process_messages") as mock_process:
            mock_process.return_value = self._canned_process_result()

            adapted_request, _ = serving_chat._convert_to_internal_request(
                [chat_request], ["test-id"]
            )

            assert adapted_request.rid == "test-id"
            # The tools are still present on the request after conversion.
            assert chat_request.tools is not None

    def test_tool_choice_none(self, serving_chat):
        """A request with ``tool_choice="none"`` still converts with its id."""
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            tools=[{"type": "function", "function": {"name": "test_func"}}],
            tool_choice="none",
        )

        with patch.object(serving_chat, "_process_messages") as mock_process:
            mock_process.return_value = self._canned_process_result()

            adapted_request, _ = serving_chat._convert_to_internal_request(
                [chat_request], ["test-id"]
            )

            # Tools should not be processed when tool_choice is "none"
            assert adapted_request.rid == "test-id"

    def test_tool_call_response_processing(self, serving_chat):
        """Parsed tool calls are returned and the finish reason becomes "tool_calls"."""
        generated_item = {
            "text": '{"name": "get_weather", "parameters": {"location": "Paris"}}',
            "meta_info": {
                "output_token_logprobs": [],
                "output_top_logprobs": None,
            },
        }
        tool_specs = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                    },
                },
            }
        ]
        stop_reason = {"type": "stop", "matched": None}

        parser_target = "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        with patch(parser_target) as mock_parser_class:
            # Tool call object handed back by the fake parser. ``name`` is set
            # after construction because ``Mock(name=...)`` has a special meaning.
            parsed_call = Mock()
            parsed_call.name = "get_weather"
            parsed_call.parameters = '{"location": "Paris"}'

            fake_parser = Mock()
            fake_parser.has_tool_call.return_value = True
            fake_parser.parse_non_stream.return_value = ("", [parsed_call])
            mock_parser_class.return_value = fake_parser

            tool_calls, text, updated_finish_reason = serving_chat._process_tool_calls(
                generated_item["text"], tool_specs, "hermes", stop_reason
            )

            assert tool_calls is not None
            assert len(tool_calls) == 1
            assert updated_finish_reason["type"] == "tool_calls"
class TestMultimodalContent:
    """Message processing tests for image and audio content."""

    def test_multimodal_request_with_images(self, serving_chat):
        """Image data from the patched template helpers reaches the result."""
        content_parts = [
            {"type": "text", "text": "What's in this image?"},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64,..."},
            },
        ]
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": content_parts}],
        )

        # Switch the mocked model config into multimodal mode.
        serving_chat.tokenizer_manager.model_config.is_multimodal = True

        with patch.object(
            serving_chat, "_apply_jinja_template"
        ) as mock_apply, patch.object(
            serving_chat, "_apply_conversation_template"
        ) as mock_conv:
            mock_apply.return_value = (
                "prompt",
                [1, 2, 3],
                ["image_data"],
                None,
                [],
                [],
            )
            mock_conv.return_value = ("prompt", ["image_data"], None, [], [])

            (
                prompt,
                _prompt_ids,
                image_data,
                _audio_data,
                _modalities,
                _stop,
                _tool_call_constraint,
            ) = serving_chat._process_messages(chat_request, True)

            assert image_data == ["image_data"]
            assert prompt == "prompt"

    def test_multimodal_request_with_audio(self, serving_chat):
        """Audio data and modalities from the patched helpers reach the result."""
        content_parts = [
            {"type": "text", "text": "Transcribe this audio"},
            {
                "type": "audio_url",
                "audio_url": {"url": "data:audio/wav;base64,UklGR..."},
            },
        ]
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": content_parts}],
        )

        serving_chat.tokenizer_manager.model_config.is_multimodal = True

        with patch.object(
            serving_chat, "_apply_jinja_template"
        ) as mock_apply, patch.object(
            serving_chat, "_apply_conversation_template"
        ) as mock_conv:
            mock_apply.return_value = (
                "prompt",
                [1, 2, 3],
                None,
                ["audio_data"],
                ["audio"],
                [],
            )
            mock_conv.return_value = ("prompt", None, ["audio_data"], ["audio"], [])

            (
                _prompt,
                _prompt_ids,
                _image_data,
                audio_data,
                modalities,
                _stop,
                _tool_call_constraint,
            ) = serving_chat._process_messages(chat_request, True)

            assert audio_data == ["audio_data"]
            assert modalities == ["audio"]
class TestTemplateHandling:
    """Message processing tests for the two chat-template paths."""

    def test_jinja_template_processing(self, serving_chat):
        """With no named template, the patched Jinja helper's prompt and ids are returned."""
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
        )

        # No named conversation template; the tokenizer carries a Jinja one.
        manager = serving_chat.tokenizer_manager
        manager.chat_template_name = None
        manager.tokenizer.chat_template = "<jinja_template>"

        # ``builtins.hasattr`` is patched to answer True for every query while
        # the messages are processed.
        with patch.object(
            serving_chat, "_apply_jinja_template"
        ) as mock_apply, patch("builtins.hasattr") as mock_hasattr:
            mock_apply.return_value = (
                "processed_prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
            )
            mock_hasattr.return_value = True

            (
                prompt,
                prompt_ids,
                _image_data,
                _audio_data,
                _modalities,
                _stop,
                _tool_call_constraint,
            ) = serving_chat._process_messages(chat_request, False)

            assert prompt == "processed_prompt"
            assert prompt_ids == [1, 2, 3]

    def test_conversation_template_processing(self, serving_chat):
        """With a named template, the patched conversation helper's prompt and stop are returned."""
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
        )
        serving_chat.tokenizer_manager.chat_template_name = "llama-3"

        with patch.object(serving_chat, "_apply_conversation_template") as mock_apply:
            mock_apply.return_value = ("conv_prompt", None, None, [], ["</s>"])

            (
                prompt,
                _prompt_ids,
                _image_data,
                _audio_data,
                _modalities,
                stop,
                _tool_call_constraint,
            ) = serving_chat._process_messages(chat_request, False)

            assert prompt == "conv_prompt"
            assert stop == ["</s>"]

    def test_continue_final_message(self, serving_chat):
        """A request with ``continue_final_message=True`` yields the patched helper's prompt."""
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
        ]
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=conversation,
            continue_final_message=True,
        )

        with patch.object(serving_chat, "_apply_conversation_template") as mock_apply:
            mock_apply.return_value = ("Hi there", None, None, [], ["</s>"])

            (
                prompt,
                _prompt_ids,
                _image_data,
                _audio_data,
                _modalities,
                _stop,
                _tool_call_constraint,
            ) = serving_chat._process_messages(chat_request, False)

            # Should handle continue_final_message properly
            assert prompt == "Hi there"
class TestReasoningContent:
    """Test reasoning content separation from adapter.py"""

    def test_reasoning_content_request(self, serving_chat):
        """A request with ``separate_reasoning=True`` converts and keeps the flag."""
        request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Solve this math problem"}],
            separate_reasoning=True,
            stream_reasoning=False,
        )

        with patch.object(serving_chat, "_process_messages") as mock_process:
            # Canned 7-tuple returned by the patched _process_messages.
            mock_process.return_value = (
                "Test prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
                None,  # tool_call_constraint
            )

            adapted_request, _ = serving_chat._convert_to_internal_request(
                [request], ["test-id"]
            )

            assert adapted_request.rid == "test-id"
            # Identity check instead of ``== True`` (PEP 8).
            assert request.separate_reasoning is True

    def test_reasoning_content_response(self, serving_chat):
        """The patched reasoning parser splits text into reasoning and answer.

        NOTE(review): this test drives only the patched ``ReasoningParser``
        mock; no ``serving_chat`` method is called. Confirm whether it should
        go through the real response-building path instead.
        """
        raw_text = "<thinking>This is reasoning</thinking>Answer: 42"

        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.ReasoningParser"
        ) as mock_parser_class:
            mock_parser = Mock()
            mock_parser.parse_non_stream.return_value = (
                "This is reasoning",
                "Answer: 42",
            )
            mock_parser_class.return_value = mock_parser

            # Simulate reasoning processing
            parser = mock_parser_class(model_type="test", stream_reasoning=False)
            reasoning_text, text = parser.parse_non_stream(raw_text)

            assert reasoning_text == "This is reasoning"
            assert text == "Answer: 42"
            # The parser must have been fed the raw generated text exactly once.
            mock_parser.parse_non_stream.assert_called_once_with(raw_text)
class TestSamplingParams:
    """Test sampling parameter handling from adapter.py"""

    @staticmethod
    def _sampling_params_for(serving_chat, request):
        """Call ``_build_sampling_params`` while ``_process_messages`` is stubbed.

        The stub returns a fixed 7-tuple; the build call always receives
        ``["</s>"]`` as stop and ``None`` as the tool-call constraint.
        """
        processed = (
            "Test prompt",
            [1, 2, 3],
            None,
            None,
            [],
            ["</s>"],
            None,  # tool_call_constraint
        )
        with patch.object(serving_chat, "_process_messages") as process_stub:
            process_stub.return_value = processed
            return serving_chat._build_sampling_params(request, ["</s>"], None)

    def test_all_sampling_parameters(self, serving_chat):
        """Every checked sampling field of the request reaches the built params."""
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.8,
            max_tokens=150,
            max_completion_tokens=200,
            min_tokens=5,
            top_p=0.9,
            top_k=50,
            min_p=0.1,
            presence_penalty=0.1,
            frequency_penalty=0.2,
            repetition_penalty=1.1,
            stop=["<|endoftext|>"],
            stop_token_ids=[13, 14],
            regex=r"\d+",
            ebnf="<expr> ::= <number>",
            n=2,
            no_stop_trim=True,
            ignore_eos=True,
            skip_special_tokens=False,
            logit_bias={"1": 0.5, "2": -0.3},
        )

        built = self._sampling_params_for(serving_chat, chat_request)

        # Verify all parameters. Note max_new_tokens is expected to follow
        # max_tokens (150) even though max_completion_tokens (200) is also set,
        # and "stop" is the list handed to the builder, not the request's own.
        expected = {
            "temperature": 0.8,
            "max_new_tokens": 150,
            "min_new_tokens": 5,
            "top_p": 0.9,
            "top_k": 50,
            "min_p": 0.1,
            "presence_penalty": 0.1,
            "frequency_penalty": 0.2,
            "repetition_penalty": 1.1,
            "stop": ["</s>"],
            "logit_bias": {"1": 0.5, "2": -0.3},
        }
        for key, value in expected.items():
            assert built[key] == value

    def test_response_format_json_schema(self, serving_chat):
        """A ``json_schema`` response format yields a ``json_schema`` param."""
        answer_schema = {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
        }
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Generate JSON"}],
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "response", "schema": answer_schema},
            },
        )

        built = self._sampling_params_for(serving_chat, chat_request)

        assert "json_schema" in built
        assert '"type": "object"' in built["json_schema"]

    def test_response_format_json_object(self, serving_chat):
        """A ``json_object`` response format maps to the generic object schema."""
        chat_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Generate JSON"}],
            response_format={"type": "json_object"},
        )

        built = self._sampling_params_for(serving_chat, chat_request)

        assert built["json_schema"] == '{"type": "object"}'

View File

@@ -0,0 +1,176 @@
"""
Tests for the refactored completions serving handler
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from sglang.srt.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionStreamResponse,
ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@pytest.fixture
def mock_tokenizer_manager():
    """Build a ``TokenizerManager``-spec'd mock for the completion tests."""
    # Tokenizer: fixed encode/decode results and a BOS id of 1.
    tokenizer = Mock()
    tokenizer.encode = Mock(return_value=[1, 2, 3, 4])
    tokenizer.decode = Mock(return_value="decoded text")
    tokenizer.bos_token_id = 1

    # Model config: text-only model.
    model_config = Mock()
    model_config.is_multimodal = False

    # Server args: cache reporting switched off.
    server_args = Mock()
    server_args.enable_cache_report = False

    manager = Mock(spec=TokenizerManager)
    manager.tokenizer = tokenizer
    manager.model_config = model_config
    manager.server_args = server_args

    # Generation entry points.
    manager.generate_request = AsyncMock()
    manager.create_abort_task = Mock(return_value=None)
    return manager
@pytest.fixture
def serving_completion(mock_tokenizer_manager):
    """Provide an OpenAIServingCompletion built on the mocked tokenizer manager."""
    handler = OpenAIServingCompletion(mock_tokenizer_manager)
    return handler
class TestPromptHandling:
    """Test different prompt types and formats from adapter.py"""

    def test_single_string_prompt(self, serving_completion):
        """A plain string prompt is carried over as the internal ``text``."""
        completion_request = CompletionRequest(
            model="test-model", prompt="Hello world", max_tokens=100
        )

        internal_request, _ = serving_completion._convert_to_internal_request(
            [completion_request], ["test-id"]
        )

        assert internal_request.text == "Hello world"

    def test_single_token_ids_prompt(self, serving_completion):
        """A flat list of token IDs is carried over as ``input_ids``."""
        token_ids = [1, 2, 3, 4]
        completion_request = CompletionRequest(
            model="test-model", prompt=token_ids, max_tokens=100
        )

        internal_request, _ = serving_completion._convert_to_internal_request(
            [completion_request], ["test-id"]
        )

        assert internal_request.input_ids == [1, 2, 3, 4]

    def test_completion_template_handling(self, serving_completion):
        """With a completion template defined, the templated prompt is used."""
        completion_request = CompletionRequest(
            model="test-model",
            prompt="def hello():",
            suffix="return 'world'",
            max_tokens=100,
        )

        serving_module = "sglang.srt.entrypoints.openai.serving_completions"
        template_defined = patch(
            f"{serving_module}.is_completion_template_defined",
            return_value=True,
        )
        prompt_builder = patch(
            f"{serving_module}.generate_completion_prompt_from_request",
            return_value="processed_prompt",
        )

        # Both module-level helpers are replaced for the duration of the call.
        with template_defined, prompt_builder:
            internal_request, _ = serving_completion._convert_to_internal_request(
                [completion_request], ["test-id"]
            )

            assert internal_request.text == "processed_prompt"
class TestEchoHandling:
    """Test echo functionality from adapter.py"""

    def test_echo_with_string_prompt_streaming(self, serving_completion):
        """``_get_echo_text`` returns a string prompt unchanged."""
        echo_request = CompletionRequest(
            model="test-model", prompt="Hello", max_tokens=100, echo=True
        )

        # Test _get_echo_text method
        assert serving_completion._get_echo_text(echo_request, 0) == "Hello"

    def test_echo_with_list_of_strings_streaming(self, serving_completion):
        """``_get_echo_text`` picks the prompt matching the given index."""
        echo_request = CompletionRequest(
            model="test-model",
            prompt=["Hello", "World"],
            max_tokens=100,
            echo=True,
            n=1,
        )

        for index, expected in enumerate(["Hello", "World"]):
            assert serving_completion._get_echo_text(echo_request, index) == expected

    def test_echo_with_token_ids_streaming(self, serving_completion):
        """A token-ID prompt is echoed as the tokenizer's decoded text."""
        echo_request = CompletionRequest(
            model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True
        )
        tokenizer = serving_completion.tokenizer_manager.tokenizer
        tokenizer.decode.return_value = "decoded_prompt"

        assert serving_completion._get_echo_text(echo_request, 0) == "decoded_prompt"

    def test_echo_with_multiple_token_ids_streaming(self, serving_completion):
        """With several token-ID prompts, index 0 is echoed as decoded text."""
        echo_request = CompletionRequest(
            model="test-model",
            prompt=[[1, 2], [3, 4]],
            max_tokens=100,
            echo=True,
            n=1,
        )
        tokenizer = serving_completion.tokenizer_manager.tokenizer
        tokenizer.decode.return_value = "decoded"

        assert serving_completion._get_echo_text(echo_request, 0) == "decoded"

    def test_prepare_echo_prompts_non_streaming(self, serving_completion):
        """``_prepare_echo_prompts`` always yields a list of prompt strings."""
        # Test with single string
        single = CompletionRequest(model="test-model", prompt="Hello", echo=True)
        assert serving_completion._prepare_echo_prompts(single) == ["Hello"]

        # Test with list of strings
        several = CompletionRequest(
            model="test-model", prompt=["Hello", "World"], echo=True
        )
        assert serving_completion._prepare_echo_prompts(several) == [
            "Hello",
            "World",
        ]

        # Test with token IDs
        from_ids = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True)
        tokenizer = serving_completion.tokenizer_manager.tokenizer
        tokenizer.decode.return_value = "decoded"
        assert serving_completion._prepare_echo_prompts(from_ids) == ["decoded"]

View File

@@ -0,0 +1,316 @@
"""
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
These tests ensure that the embedding serving implementation maintains compatibility
with the original adapter.py functionality and follows OpenAI API specifications.
"""
import asyncio
import json
import time
import uuid
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi import Request
from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError
from sglang.srt.entrypoints.openai.protocol import (
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
MultimodalEmbeddingInput,
UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput
# Mock TokenizerManager for embedding tests
class MockTokenizerManager:
    """Hand-rolled stand-in for the tokenizer manager used by embedding tests.

    Exposes the attributes set below (``model_config``, ``server_args``,
    ``model_path``, ``tokenizer``) plus a ``generate_request`` mock whose
    result is an async generator yielding a single embedding payload.
    """

    def __init__(self):
        self.model_config = Mock()
        self.model_config.is_multimodal = False
        self.server_args = Mock()
        self.server_args.enable_cache_report = False
        self.model_path = "test-model"

        # Mock tokenizer
        self.tokenizer = Mock()
        self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
        self.tokenizer.decode = Mock(return_value="Test embedding input")
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        # Mock generate_request method for embeddings
        async def mock_generate_embedding():
            yield {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20,  # 100-dim embedding
                "meta_info": {
                    "id": f"embd-{uuid.uuid4()}",
                    "prompt_tokens": 5,
                },
            }

        # Use side_effect so every call gets a *fresh* async generator.
        # `return_value=mock_generate_embedding()` would hand the same generator
        # object to every caller, and it is exhausted after the first consumer.
        self.generate_request = Mock(
            side_effect=lambda *args, **kwargs: mock_generate_embedding()
        )
@pytest.fixture
def mock_tokenizer_manager():
    """Provide a fresh MockTokenizerManager for each test."""
    manager = MockTokenizerManager()
    return manager
@pytest.fixture
def serving_embedding(mock_tokenizer_manager):
    """Provide an OpenAIServingEmbedding built on the mocked tokenizer manager."""
    handler = OpenAIServingEmbedding(mock_tokenizer_manager)
    return handler
@pytest.fixture
def mock_request():
    """Provide a FastAPI ``Request``-spec'd mock with empty headers."""
    # Keyword arguments to Mock are set as attributes on the new mock.
    return Mock(spec=Request, headers={})
@pytest.fixture
def basic_embedding_request():
    """Provide an embedding request whose input is a single string."""
    single_text = "Hello, how are you?"
    return EmbeddingRequest(
        model="test-model", input=single_text, encoding_format="float"
    )
@pytest.fixture
def list_embedding_request():
    """Provide an embedding request whose input is a list of two strings."""
    texts = ["Hello, how are you?", "I am fine, thank you!"]
    return EmbeddingRequest(model="test-model", input=texts, encoding_format="float")
@pytest.fixture
def multimodal_embedding_request():
    """Provide an embedding request mixing a text+image and a text-only input."""
    with_image = MultimodalEmbeddingInput(text="Hello", image="base64_image_data")
    text_only = MultimodalEmbeddingInput(text="World", image=None)
    return EmbeddingRequest(
        model="test-model",
        input=[with_image, text_only],
        encoding_format="float",
    )
@pytest.fixture
def token_ids_embedding_request():
    """Provide an embedding request whose input is a flat list of token IDs."""
    token_ids = [1, 2, 3, 4, 5]
    return EmbeddingRequest(
        model="test-model", input=token_ids, encoding_format="float"
    )
class TestOpenAIServingEmbeddingConversion:
    """Test request conversion methods."""

    def test_convert_single_string_request(
        self, serving_embedding, basic_embedding_request
    ):
        """A single-string request becomes an EmbeddingReqInput with that text."""
        converted = serving_embedding._convert_to_internal_request(
            [basic_embedding_request], ["test-id"]
        )
        internal, echoed = converted

        assert isinstance(internal, EmbeddingReqInput)
        assert internal.text == "Hello, how are you?"
        assert internal.rid == "test-id"
        assert echoed == basic_embedding_request

    def test_convert_list_string_request(
        self, serving_embedding, list_embedding_request
    ):
        """A list-of-strings request keeps the whole list as ``text``."""
        converted = serving_embedding._convert_to_internal_request(
            [list_embedding_request], ["test-id"]
        )
        internal, echoed = converted

        assert isinstance(internal, EmbeddingReqInput)
        assert internal.text == ["Hello, how are you?", "I am fine, thank you!"]
        assert internal.rid == "test-id"
        assert echoed == list_embedding_request

    def test_convert_token_ids_request(
        self, serving_embedding, token_ids_embedding_request
    ):
        """A token-ID request is passed through as ``input_ids``."""
        converted = serving_embedding._convert_to_internal_request(
            [token_ids_embedding_request], ["test-id"]
        )
        internal, echoed = converted

        assert isinstance(internal, EmbeddingReqInput)
        assert internal.input_ids == [1, 2, 3, 4, 5]
        assert internal.rid == "test-id"
        assert echoed == token_ids_embedding_request

    def test_convert_multimodal_request(
        self, serving_embedding, multimodal_embedding_request
    ):
        """A multimodal request is split into parallel text and image lists."""
        converted = serving_embedding._convert_to_internal_request(
            [multimodal_embedding_request], ["test-id"]
        )
        internal, _echoed = converted

        assert isinstance(internal, EmbeddingReqInput)
        # Should extract text and images separately
        assert len(internal.text) == 2
        for expected_text in ("Hello", "World"):
            assert expected_text in internal.text
        first_image, second_image = internal.image_data[0], internal.image_data[1]
        assert first_image == "base64_image_data"
        assert second_image is None
        assert internal.rid == "test-id"
class TestEmbeddingResponseBuilding:
    """Test response building methods."""

    def test_build_single_embedding_response(self, serving_embedding):
        """One result item yields one data entry and matching usage counts."""
        single_result = {
            "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
            "meta_info": {"prompt_tokens": 5},
        }

        built = serving_embedding._build_embedding_response(
            [single_result], "test-model"
        )

        assert isinstance(built, EmbeddingResponse)
        assert built.model == "test-model"
        assert len(built.data) == 1

        entry = built.data[0]
        assert entry.embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
        assert entry.index == 0
        assert entry.object == "embedding"

        usage = built.usage
        assert usage.prompt_tokens == 5
        assert usage.total_tokens == 5
        assert usage.completion_tokens == 0

    def test_build_multiple_embedding_response(self, serving_embedding):
        """Several result items keep their order and sum their prompt tokens."""
        results = [
            {"embedding": [0.1, 0.2, 0.3], "meta_info": {"prompt_tokens": 3}},
            {"embedding": [0.4, 0.5, 0.6], "meta_info": {"prompt_tokens": 4}},
        ]

        built = serving_embedding._build_embedding_response(results, "test-model")

        assert isinstance(built, EmbeddingResponse)
        assert len(built.data) == 2

        expected_vectors = ([0.1, 0.2, 0.3], [0.4, 0.5, 0.6])
        for position, vector in enumerate(expected_vectors):
            assert built.data[position].embedding == vector
            assert built.data[position].index == position

        assert built.usage.prompt_tokens == 7  # 3 + 4
        assert built.usage.total_tokens == 7
@pytest.mark.asyncio
class TestOpenAIServingEmbeddingAsyncMethods:
    """Test async methods of OpenAIServingEmbedding."""

    async def test_handle_request_success(
        self, serving_embedding, basic_embedding_request, mock_request
    ):
        """A successful generation is returned as an EmbeddingResponse."""

        # Mock the generate_request to return expected data
        async def single_embedding_stream():
            yield {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
                "meta_info": {"prompt_tokens": 5},
            }

        tokenizer_manager = serving_embedding.tokenizer_manager
        tokenizer_manager.generate_request = Mock(
            return_value=single_embedding_stream()
        )

        result = await serving_embedding.handle_request(
            basic_embedding_request, mock_request
        )

        assert isinstance(result, EmbeddingResponse)
        assert len(result.data) == 1
        assert result.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]

    async def test_handle_request_validation_error(
        self, serving_embedding, mock_request
    ):
        """An empty-string input is answered with an HTTP 400 JSON response."""
        empty_input_request = EmbeddingRequest(model="test-model", input="")

        result = await serving_embedding.handle_request(
            empty_input_request, mock_request
        )

        assert isinstance(result, ORJSONResponse)
        assert result.status_code == 400

    async def test_handle_request_generation_error(
        self, serving_embedding, basic_embedding_request, mock_request
    ):
        """A ValueError raised while generating is answered with HTTP 400."""

        # Mock generate_request to raise an error
        async def failing_stream():
            raise ValueError("Generation failed")
            yield  # Never reached; its presence makes this an async generator

        tokenizer_manager = serving_embedding.tokenizer_manager
        tokenizer_manager.generate_request = Mock(return_value=failing_stream())

        result = await serving_embedding.handle_request(
            basic_embedding_request, mock_request
        )

        assert isinstance(result, ORJSONResponse)
        assert result.status_code == 400

    async def test_handle_request_internal_error(
        self, serving_embedding, basic_embedding_request, mock_request
    ):
        """An unexpected exception during conversion is answered with HTTP 500."""
        # Mock _convert_to_internal_request to raise an exception
        conversion_failure = patch.object(
            serving_embedding,
            "_convert_to_internal_request",
            side_effect=Exception("Internal error"),
        )

        with conversion_failure:
            result = await serving_embedding.handle_request(
                basic_embedding_request, mock_request
            )

            assert isinstance(result, ORJSONResponse)
            assert result.status_code == 500