Files
sglang/test/srt/openai/test_serving_embedding.py
Xinyuan Tong 70c471a868 [Refactor] OAI Server components (#7167)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
2025-06-16 20:45:20 -07:00

317 lines
10 KiB
Python

"""
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
These tests ensure that the embedding serving implementation maintains compatibility
with the original adapter.py functionality and follows OpenAI API specifications.
"""
import asyncio
import json
import time
import uuid
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi import Request
from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError
from sglang.srt.entrypoints.openai.protocol import (
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
MultimodalEmbeddingInput,
UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput
# Mock TokenizerManager for embedding tests
class MockTokenizerManager:
def __init__(self):
self.model_config = Mock()
self.model_config.is_multimodal = False
self.server_args = Mock()
self.server_args.enable_cache_report = False
self.model_path = "test-model"
# Mock tokenizer
self.tokenizer = Mock()
self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
self.tokenizer.decode = Mock(return_value="Test embedding input")
self.tokenizer.chat_template = None
self.tokenizer.bos_token_id = 1
# Mock generate_request method for embeddings
async def mock_generate_embedding():
yield {
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20, # 100-dim embedding
"meta_info": {
"id": f"embd-{uuid.uuid4()}",
"prompt_tokens": 5,
},
}
self.generate_request = Mock(return_value=mock_generate_embedding())
@pytest.fixture
def mock_tokenizer_manager():
"""Create a mock tokenizer manager for testing."""
return MockTokenizerManager()
@pytest.fixture
def serving_embedding(mock_tokenizer_manager):
"""Create an OpenAIServingEmbedding instance for testing."""
return OpenAIServingEmbedding(mock_tokenizer_manager)
@pytest.fixture
def mock_request():
"""Create a mock FastAPI request."""
request = Mock(spec=Request)
request.headers = {}
return request
@pytest.fixture
def basic_embedding_request():
"""Create a basic embedding request."""
return EmbeddingRequest(
model="test-model",
input="Hello, how are you?",
encoding_format="float",
)
@pytest.fixture
def list_embedding_request():
"""Create an embedding request with list input."""
return EmbeddingRequest(
model="test-model",
input=["Hello, how are you?", "I am fine, thank you!"],
encoding_format="float",
)
@pytest.fixture
def multimodal_embedding_request():
"""Create a multimodal embedding request."""
return EmbeddingRequest(
model="test-model",
input=[
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
MultimodalEmbeddingInput(text="World", image=None),
],
encoding_format="float",
)
@pytest.fixture
def token_ids_embedding_request():
"""Create an embedding request with token IDs."""
return EmbeddingRequest(
model="test-model",
input=[1, 2, 3, 4, 5],
encoding_format="float",
)
class TestOpenAIServingEmbeddingConversion:
"""Test request conversion methods."""
def test_convert_single_string_request(
self, serving_embedding, basic_embedding_request
):
"""Test converting single string request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[basic_embedding_request], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
assert adapted_request.text == "Hello, how are you?"
assert adapted_request.rid == "test-id"
assert processed_request == basic_embedding_request
def test_convert_list_string_request(
self, serving_embedding, list_embedding_request
):
"""Test converting list of strings request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[list_embedding_request], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"]
assert adapted_request.rid == "test-id"
assert processed_request == list_embedding_request
def test_convert_token_ids_request(
self, serving_embedding, token_ids_embedding_request
):
"""Test converting token IDs request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[token_ids_embedding_request], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
assert adapted_request.input_ids == [1, 2, 3, 4, 5]
assert adapted_request.rid == "test-id"
assert processed_request == token_ids_embedding_request
def test_convert_multimodal_request(
self, serving_embedding, multimodal_embedding_request
):
"""Test converting multimodal request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[multimodal_embedding_request], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
# Should extract text and images separately
assert len(adapted_request.text) == 2
assert "Hello" in adapted_request.text
assert "World" in adapted_request.text
assert adapted_request.image_data[0] == "base64_image_data"
assert adapted_request.image_data[1] is None
assert adapted_request.rid == "test-id"
class TestEmbeddingResponseBuilding:
"""Test response building methods."""
def test_build_single_embedding_response(self, serving_embedding):
"""Test building response for single embedding."""
ret_data = [
{
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"meta_info": {"prompt_tokens": 5},
}
]
response = serving_embedding._build_embedding_response(ret_data, "test-model")
assert isinstance(response, EmbeddingResponse)
assert response.model == "test-model"
assert len(response.data) == 1
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
assert response.data[0].index == 0
assert response.data[0].object == "embedding"
assert response.usage.prompt_tokens == 5
assert response.usage.total_tokens == 5
assert response.usage.completion_tokens == 0
def test_build_multiple_embedding_response(self, serving_embedding):
"""Test building response for multiple embeddings."""
ret_data = [
{
"embedding": [0.1, 0.2, 0.3],
"meta_info": {"prompt_tokens": 3},
},
{
"embedding": [0.4, 0.5, 0.6],
"meta_info": {"prompt_tokens": 4},
},
]
response = serving_embedding._build_embedding_response(ret_data, "test-model")
assert isinstance(response, EmbeddingResponse)
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
assert response.data[0].index == 0
assert response.data[1].embedding == [0.4, 0.5, 0.6]
assert response.data[1].index == 1
assert response.usage.prompt_tokens == 7 # 3 + 4
assert response.usage.total_tokens == 7
@pytest.mark.asyncio
class TestOpenAIServingEmbeddingAsyncMethods:
"""Test async methods of OpenAIServingEmbedding."""
async def test_handle_request_success(
self, serving_embedding, basic_embedding_request, mock_request
):
"""Test successful embedding request handling."""
# Mock the generate_request to return expected data
async def mock_generate():
yield {
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"meta_info": {"prompt_tokens": 5},
}
serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate()
)
response = await serving_embedding.handle_request(
basic_embedding_request, mock_request
)
assert isinstance(response, EmbeddingResponse)
assert len(response.data) == 1
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
async def test_handle_request_validation_error(
self, serving_embedding, mock_request
):
"""Test handling request with validation error."""
invalid_request = EmbeddingRequest(model="test-model", input="")
response = await serving_embedding.handle_request(invalid_request, mock_request)
assert isinstance(response, ORJSONResponse)
assert response.status_code == 400
async def test_handle_request_generation_error(
self, serving_embedding, basic_embedding_request, mock_request
):
"""Test handling request with generation error."""
# Mock generate_request to raise an error
async def mock_generate_error():
raise ValueError("Generation failed")
yield # This won't be reached but needed for async generator
serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate_error()
)
response = await serving_embedding.handle_request(
basic_embedding_request, mock_request
)
assert isinstance(response, ORJSONResponse)
assert response.status_code == 400
async def test_handle_request_internal_error(
self, serving_embedding, basic_embedding_request, mock_request
):
"""Test handling request with internal server error."""
# Mock _convert_to_internal_request to raise an exception
with patch.object(
serving_embedding,
"_convert_to_internal_request",
side_effect=Exception("Internal error"),
):
response = await serving_embedding.handle_request(
basic_embedding_request, mock_request
)
assert isinstance(response, ORJSONResponse)
assert response.status_code == 500