feat(oai refactor): Replace openai_api with entrypoints/openai (#7351)
Co-authored-by: Jin Pan <jpan236@wisc.edu>
This commit is contained in:
@@ -5,25 +5,16 @@ These tests ensure that the embedding serving implementation maintains compatibi
|
||||
with the original adapter.py functionality and follows OpenAI API specifications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from pydantic_core import ValidationError
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
EmbeddingObject,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ErrorResponse,
|
||||
MultimodalEmbeddingInput,
|
||||
UsageInfo,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
||||
@@ -58,11 +49,22 @@ class _MockTokenizerManager:
|
||||
self.generate_request = Mock(return_value=mock_generate_embedding())
|
||||
|
||||
|
||||
# Mock TemplateManager for embedding tests
|
||||
class _MockTemplateManager:
|
||||
def __init__(self):
|
||||
self.chat_template_name = None # None for embeddings usually
|
||||
self.jinja_template_content_format = None
|
||||
self.completion_template_name = None
|
||||
|
||||
|
||||
class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.tokenizer_manager = _MockTokenizerManager()
|
||||
self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.serving_embedding = OpenAIServingEmbedding(
|
||||
self.tokenizer_manager, self.template_manager
|
||||
)
|
||||
|
||||
self.request = Mock(spec=Request)
|
||||
self.request.headers = {}
|
||||
@@ -141,132 +143,6 @@ class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
self.assertIsNone(adapted_request.image_data[1])
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
|
||||
def test_build_single_embedding_response(self):
|
||||
"""Test building response for single embedding."""
|
||||
ret_data = [
|
||||
{
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"meta_info": {"prompt_tokens": 5},
|
||||
}
|
||||
]
|
||||
|
||||
response = self.serving_embedding._build_embedding_response(ret_data)
|
||||
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(response.model, "test-model")
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
self.assertEqual(response.data[0].index, 0)
|
||||
self.assertEqual(response.data[0].object, "embedding")
|
||||
self.assertEqual(response.usage.prompt_tokens, 5)
|
||||
self.assertEqual(response.usage.total_tokens, 5)
|
||||
self.assertEqual(response.usage.completion_tokens, 0)
|
||||
|
||||
def test_build_multiple_embedding_response(self):
|
||||
"""Test building response for multiple embeddings."""
|
||||
ret_data = [
|
||||
{
|
||||
"embedding": [0.1, 0.2, 0.3],
|
||||
"meta_info": {"prompt_tokens": 3},
|
||||
},
|
||||
{
|
||||
"embedding": [0.4, 0.5, 0.6],
|
||||
"meta_info": {"prompt_tokens": 4},
|
||||
},
|
||||
]
|
||||
|
||||
response = self.serving_embedding._build_embedding_response(ret_data)
|
||||
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(len(response.data), 2)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
|
||||
self.assertEqual(response.data[0].index, 0)
|
||||
self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
|
||||
self.assertEqual(response.data[1].index, 1)
|
||||
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
|
||||
self.assertEqual(response.usage.total_tokens, 7)
|
||||
|
||||
def test_handle_request_success(self):
|
||||
"""Test successful embedding request handling."""
|
||||
|
||||
async def run_test():
|
||||
# Mock the generate_request to return expected data
|
||||
async def mock_generate():
|
||||
yield {
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"meta_info": {"prompt_tokens": 5},
|
||||
}
|
||||
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate()
|
||||
)
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_validation_error(self):
|
||||
"""Test handling request with validation error."""
|
||||
|
||||
async def run_test():
|
||||
invalid_request = EmbeddingRequest(model="test-model", input="")
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
invalid_request, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_generation_error(self):
|
||||
"""Test handling request with generation error."""
|
||||
|
||||
async def run_test():
|
||||
# Mock generate_request to raise an error
|
||||
async def mock_generate_error():
|
||||
raise ValueError("Generation failed")
|
||||
yield # This won't be reached but needed for async generator
|
||||
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate_error()
|
||||
)
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_internal_error(self):
|
||||
"""Test handling request with internal server error."""
|
||||
|
||||
async def run_test():
|
||||
# Mock _convert_to_internal_request to raise an exception
|
||||
with patch.object(
|
||||
self.serving_embedding,
|
||||
"_convert_to_internal_request",
|
||||
side_effect=Exception("Internal error"),
|
||||
):
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 500)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user