Files
sglang/test/srt/openai/test_serving_embedding.py
2025-06-21 13:21:06 -07:00

149 lines
5.4 KiB
Python

"""
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
These tests ensure that the embedding serving implementation maintains compatibility
with the original adapter.py functionality and follows OpenAI API specifications.
"""
import unittest
import uuid
from unittest.mock import Mock
from fastapi import Request
from sglang.srt.entrypoints.openai.protocol import (
EmbeddingRequest,
EmbeddingResponse,
MultimodalEmbeddingInput,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput
# Mock TokenizerManager for embedding tests
class _MockTokenizerManager:
def __init__(self):
self.model_config = Mock()
self.model_config.is_multimodal = False
self.server_args = Mock()
self.server_args.enable_cache_report = False
self.model_path = "test-model"
# Mock tokenizer
self.tokenizer = Mock()
self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
self.tokenizer.decode = Mock(return_value="Test embedding input")
self.tokenizer.chat_template = None
self.tokenizer.bos_token_id = 1
# Mock generate_request method for embeddings
async def mock_generate_embedding():
yield {
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20, # 100-dim embedding
"meta_info": {
"id": f"embd-{uuid.uuid4()}",
"prompt_tokens": 5,
},
}
self.generate_request = Mock(return_value=mock_generate_embedding())
# Mock TemplateManager for embedding tests
class _MockTemplateManager:
def __init__(self):
self.chat_template_name = None # None for embeddings usually
self.jinja_template_content_format = None
self.completion_template_name = None
class ServingEmbeddingTestCase(unittest.TestCase):
    """Tests for ``OpenAIServingEmbedding._convert_to_internal_request``.

    Each test converts one ``EmbeddingRequest`` input variant (single string,
    list of strings, token IDs, multimodal items) and checks the resulting
    ``EmbeddingReqInput`` as well as the request object returned with it.
    """

    def setUp(self):
        """Set up test fixtures."""
        self.tokenizer_manager = _MockTokenizerManager()
        self.template_manager = _MockTemplateManager()
        self.serving_embedding = OpenAIServingEmbedding(
            self.tokenizer_manager, self.template_manager
        )

        self.request = Mock(spec=Request)
        self.request.headers = {}

        self.basic_req = self._make_request("Hello, how are you?")
        self.list_req = self._make_request(
            ["Hello, how are you?", "I am fine, thank you!"]
        )
        self.multimodal_req = self._make_request(
            [
                MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
                MultimodalEmbeddingInput(text="World", image=None),
            ]
        )
        self.token_ids_req = self._make_request([1, 2, 3, 4, 5])

    @staticmethod
    def _make_request(input_):
        """Build an EmbeddingRequest for the shared test model with float encoding."""
        return EmbeddingRequest(
            model="test-model",
            input=input_,
            encoding_format="float",
        )

    def _convert(self, req):
        """Run the converter under test; returns (adapted_request, processed_request)."""
        return self.serving_embedding._convert_to_internal_request(req)

    def test_convert_single_string_request(self):
        """Test converting single string request to internal format."""
        adapted_request, processed_request = self._convert(self.basic_req)

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.text, "Hello, how are you?")
        self.assertEqual(processed_request, self.basic_req)

    def test_convert_list_string_request(self):
        """Test converting list of strings request to internal format."""
        adapted_request, processed_request = self._convert(self.list_req)

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(
            adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
        )
        self.assertEqual(processed_request, self.list_req)

    def test_convert_token_ids_request(self):
        """Test converting token IDs request to internal format."""
        adapted_request, processed_request = self._convert(self.token_ids_req)

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
        self.assertEqual(processed_request, self.token_ids_req)

    def test_convert_multimodal_request(self):
        """Test converting multimodal request to internal format."""
        adapted_request, processed_request = self._convert(self.multimodal_req)

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        # Should extract text and images separately, index-aligned per item.
        self.assertEqual(len(adapted_request.text), 2)
        self.assertIn("Hello", adapted_request.text)
        self.assertIn("World", adapted_request.text)
        self.assertEqual(adapted_request.image_data[0], "base64_image_data")
        self.assertIsNone(adapted_request.image_data[1])
        # Same contract as the other conversions: the request is passed back.
        self.assertEqual(processed_request, self.multimodal_req)
if __name__ == "__main__":
    # Discover and run every TestCase in this module with verbose output.
    unittest.main(module="__main__", verbosity=2)