Add more refactored openai test & in CI (#7284)
This commit is contained in:
@@ -8,11 +8,11 @@ with the original adapter.py functionality and follows OpenAI API specifications
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from pydantic_core import ValidationError
|
||||
@@ -30,7 +30,7 @@ from sglang.srt.managers.io_struct import EmbeddingReqInput
|
||||
|
||||
|
||||
# Mock TokenizerManager for embedding tests
|
||||
class MockTokenizerManager:
|
||||
class _MockTokenizerManager:
|
||||
def __init__(self):
|
||||
self.model_config = Mock()
|
||||
self.model_config.is_multimodal = False
|
||||
@@ -58,141 +58,98 @@ class MockTokenizerManager:
|
||||
self.generate_request = Mock(return_value=mock_generate_embedding())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tokenizer_manager():
|
||||
"""Create a mock tokenizer manager for testing."""
|
||||
return MockTokenizerManager()
|
||||
class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.tokenizer_manager = _MockTokenizerManager()
|
||||
self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
|
||||
|
||||
self.request = Mock(spec=Request)
|
||||
self.request.headers = {}
|
||||
|
||||
@pytest.fixture
|
||||
def serving_embedding(mock_tokenizer_manager):
|
||||
"""Create an OpenAIServingEmbedding instance for testing."""
|
||||
return OpenAIServingEmbedding(mock_tokenizer_manager)
|
||||
self.basic_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input="Hello, how are you?",
|
||||
encoding_format="float",
|
||||
)
|
||||
self.list_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=["Hello, how are you?", "I am fine, thank you!"],
|
||||
encoding_format="float",
|
||||
)
|
||||
self.multimodal_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=[
|
||||
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
|
||||
MultimodalEmbeddingInput(text="World", image=None),
|
||||
],
|
||||
encoding_format="float",
|
||||
)
|
||||
self.token_ids_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=[1, 2, 3, 4, 5],
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock FastAPI request."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_embedding_request():
|
||||
"""Create a basic embedding request."""
|
||||
return EmbeddingRequest(
|
||||
model="test-model",
|
||||
input="Hello, how are you?",
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def list_embedding_request():
|
||||
"""Create an embedding request with list input."""
|
||||
return EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=["Hello, how are you?", "I am fine, thank you!"],
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multimodal_embedding_request():
|
||||
"""Create a multimodal embedding request."""
|
||||
return EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=[
|
||||
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
|
||||
MultimodalEmbeddingInput(text="World", image=None),
|
||||
],
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token_ids_embedding_request():
|
||||
"""Create an embedding request with token IDs."""
|
||||
return EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=[1, 2, 3, 4, 5],
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIServingEmbeddingConversion:
|
||||
"""Test request conversion methods."""
|
||||
|
||||
def test_convert_single_string_request(
|
||||
self, serving_embedding, basic_embedding_request
|
||||
):
|
||||
def test_convert_single_string_request(self):
|
||||
"""Test converting single string request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
serving_embedding._convert_to_internal_request(
|
||||
[basic_embedding_request], ["test-id"]
|
||||
self.serving_embedding._convert_to_internal_request(
|
||||
[self.basic_req], ["test-id"]
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(adapted_request, EmbeddingReqInput)
|
||||
assert adapted_request.text == "Hello, how are you?"
|
||||
assert adapted_request.rid == "test-id"
|
||||
assert processed_request == basic_embedding_request
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(adapted_request.text, "Hello, how are you?")
|
||||
self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.basic_req)
|
||||
|
||||
def test_convert_list_string_request(
|
||||
self, serving_embedding, list_embedding_request
|
||||
):
|
||||
def test_convert_list_string_request(self):
|
||||
"""Test converting list of strings request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
serving_embedding._convert_to_internal_request(
|
||||
[list_embedding_request], ["test-id"]
|
||||
self.serving_embedding._convert_to_internal_request(
|
||||
[self.list_req], ["test-id"]
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(adapted_request, EmbeddingReqInput)
|
||||
assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"]
|
||||
assert adapted_request.rid == "test-id"
|
||||
assert processed_request == list_embedding_request
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(
|
||||
adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
|
||||
)
|
||||
self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.list_req)
|
||||
|
||||
def test_convert_token_ids_request(
|
||||
self, serving_embedding, token_ids_embedding_request
|
||||
):
|
||||
def test_convert_token_ids_request(self):
|
||||
"""Test converting token IDs request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
serving_embedding._convert_to_internal_request(
|
||||
[token_ids_embedding_request], ["test-id"]
|
||||
self.serving_embedding._convert_to_internal_request(
|
||||
[self.token_ids_req], ["test-id"]
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(adapted_request, EmbeddingReqInput)
|
||||
assert adapted_request.input_ids == [1, 2, 3, 4, 5]
|
||||
assert adapted_request.rid == "test-id"
|
||||
assert processed_request == token_ids_embedding_request
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
|
||||
self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.token_ids_req)
|
||||
|
||||
def test_convert_multimodal_request(
|
||||
self, serving_embedding, multimodal_embedding_request
|
||||
):
|
||||
def test_convert_multimodal_request(self):
|
||||
"""Test converting multimodal request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
serving_embedding._convert_to_internal_request(
|
||||
[multimodal_embedding_request], ["test-id"]
|
||||
self.serving_embedding._convert_to_internal_request(
|
||||
[self.multimodal_req], ["test-id"]
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
# Should extract text and images separately
|
||||
assert len(adapted_request.text) == 2
|
||||
assert "Hello" in adapted_request.text
|
||||
assert "World" in adapted_request.text
|
||||
assert adapted_request.image_data[0] == "base64_image_data"
|
||||
assert adapted_request.image_data[1] is None
|
||||
assert adapted_request.rid == "test-id"
|
||||
self.assertEqual(len(adapted_request.text), 2)
|
||||
self.assertIn("Hello", adapted_request.text)
|
||||
self.assertIn("World", adapted_request.text)
|
||||
self.assertEqual(adapted_request.image_data[0], "base64_image_data")
|
||||
self.assertIsNone(adapted_request.image_data[1])
|
||||
self.assertEqual(adapted_request.rid, "test-id")
|
||||
|
||||
|
||||
class TestEmbeddingResponseBuilding:
|
||||
"""Test response building methods."""
|
||||
|
||||
def test_build_single_embedding_response(self, serving_embedding):
|
||||
def test_build_single_embedding_response(self):
|
||||
"""Test building response for single embedding."""
|
||||
ret_data = [
|
||||
{
|
||||
@@ -201,19 +158,21 @@ class TestEmbeddingResponseBuilding:
|
||||
}
|
||||
]
|
||||
|
||||
response = serving_embedding._build_embedding_response(ret_data, "test-model")
|
||||
response = self.serving_embedding._build_embedding_response(
|
||||
ret_data, "test-model"
|
||||
)
|
||||
|
||||
assert isinstance(response, EmbeddingResponse)
|
||||
assert response.model == "test-model"
|
||||
assert len(response.data) == 1
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
assert response.data[0].index == 0
|
||||
assert response.data[0].object == "embedding"
|
||||
assert response.usage.prompt_tokens == 5
|
||||
assert response.usage.total_tokens == 5
|
||||
assert response.usage.completion_tokens == 0
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(response.model, "test-model")
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
self.assertEqual(response.data[0].index, 0)
|
||||
self.assertEqual(response.data[0].object, "embedding")
|
||||
self.assertEqual(response.usage.prompt_tokens, 5)
|
||||
self.assertEqual(response.usage.total_tokens, 5)
|
||||
self.assertEqual(response.usage.completion_tokens, 0)
|
||||
|
||||
def test_build_multiple_embedding_response(self, serving_embedding):
|
||||
def test_build_multiple_embedding_response(self):
|
||||
"""Test building response for multiple embeddings."""
|
||||
ret_data = [
|
||||
{
|
||||
@@ -226,25 +185,20 @@ class TestEmbeddingResponseBuilding:
|
||||
},
|
||||
]
|
||||
|
||||
response = serving_embedding._build_embedding_response(ret_data, "test-model")
|
||||
response = self.serving_embedding._build_embedding_response(
|
||||
ret_data, "test-model"
|
||||
)
|
||||
|
||||
assert isinstance(response, EmbeddingResponse)
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
||||
assert response.data[0].index == 0
|
||||
assert response.data[1].embedding == [0.4, 0.5, 0.6]
|
||||
assert response.data[1].index == 1
|
||||
assert response.usage.prompt_tokens == 7 # 3 + 4
|
||||
assert response.usage.total_tokens == 7
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(len(response.data), 2)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
|
||||
self.assertEqual(response.data[0].index, 0)
|
||||
self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
|
||||
self.assertEqual(response.data[1].index, 1)
|
||||
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
|
||||
self.assertEqual(response.usage.total_tokens, 7)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestOpenAIServingEmbeddingAsyncMethods:
|
||||
"""Test async methods of OpenAIServingEmbedding."""
|
||||
|
||||
async def test_handle_request_success(
|
||||
self, serving_embedding, basic_embedding_request, mock_request
|
||||
):
|
||||
async def test_handle_request_success(self):
|
||||
"""Test successful embedding request handling."""
|
||||
|
||||
# Mock the generate_request to return expected data
|
||||
@@ -254,32 +208,30 @@ class TestOpenAIServingEmbeddingAsyncMethods:
|
||||
"meta_info": {"prompt_tokens": 5},
|
||||
}
|
||||
|
||||
serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate()
|
||||
)
|
||||
|
||||
response = await serving_embedding.handle_request(
|
||||
basic_embedding_request, mock_request
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
assert isinstance(response, EmbeddingResponse)
|
||||
assert len(response.data) == 1
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
|
||||
async def test_handle_request_validation_error(
|
||||
self, serving_embedding, mock_request
|
||||
):
|
||||
async def test_handle_request_validation_error(self):
|
||||
"""Test handling request with validation error."""
|
||||
invalid_request = EmbeddingRequest(model="test-model", input="")
|
||||
|
||||
response = await serving_embedding.handle_request(invalid_request, mock_request)
|
||||
response = await self.serving_embedding.handle_request(
|
||||
invalid_request, self.request
|
||||
)
|
||||
|
||||
assert isinstance(response, ORJSONResponse)
|
||||
assert response.status_code == 400
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
async def test_handle_request_generation_error(
|
||||
self, serving_embedding, basic_embedding_request, mock_request
|
||||
):
|
||||
async def test_handle_request_generation_error(self):
|
||||
"""Test handling request with generation error."""
|
||||
|
||||
# Mock generate_request to raise an error
|
||||
@@ -287,30 +239,32 @@ class TestOpenAIServingEmbeddingAsyncMethods:
|
||||
raise ValueError("Generation failed")
|
||||
yield # This won't be reached but needed for async generator
|
||||
|
||||
serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate_error()
|
||||
)
|
||||
|
||||
response = await serving_embedding.handle_request(
|
||||
basic_embedding_request, mock_request
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
assert isinstance(response, ORJSONResponse)
|
||||
assert response.status_code == 400
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
async def test_handle_request_internal_error(
|
||||
self, serving_embedding, basic_embedding_request, mock_request
|
||||
):
|
||||
async def test_handle_request_internal_error(self):
|
||||
"""Test handling request with internal server error."""
|
||||
# Mock _convert_to_internal_request to raise an exception
|
||||
with patch.object(
|
||||
serving_embedding,
|
||||
self.serving_embedding,
|
||||
"_convert_to_internal_request",
|
||||
side_effect=Exception("Internal error"),
|
||||
):
|
||||
response = await serving_embedding.handle_request(
|
||||
basic_embedding_request, mock_request
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
assert isinstance(response, ORJSONResponse)
|
||||
assert response.status_code == 500
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 500)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user