Add more refactored openai test & in CI (#7284)

This commit is contained in:
Jinn
2025-06-18 15:52:55 -05:00
committed by GitHub
parent 09ae5b20f3
commit ffd1a26e09
8 changed files with 566 additions and 1049 deletions

View File

@@ -8,11 +8,11 @@ with the original adapter.py functionality and follows OpenAI API specifications
import asyncio
import json
import time
import unittest
import uuid
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi import Request
from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError
@@ -30,7 +30,7 @@ from sglang.srt.managers.io_struct import EmbeddingReqInput
# Mock TokenizerManager for embedding tests
class MockTokenizerManager:
class _MockTokenizerManager:
def __init__(self):
self.model_config = Mock()
self.model_config.is_multimodal = False
@@ -58,141 +58,98 @@ class MockTokenizerManager:
self.generate_request = Mock(return_value=mock_generate_embedding())
@pytest.fixture
def mock_tokenizer_manager():
"""Create a mock tokenizer manager for testing."""
return MockTokenizerManager()
class ServingEmbeddingTestCase(unittest.TestCase):
def setUp(self):
"""Set up test fixtures."""
self.tokenizer_manager = _MockTokenizerManager()
self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
self.request = Mock(spec=Request)
self.request.headers = {}
@pytest.fixture
def serving_embedding(mock_tokenizer_manager):
"""Create an OpenAIServingEmbedding instance for testing."""
return OpenAIServingEmbedding(mock_tokenizer_manager)
self.basic_req = EmbeddingRequest(
model="test-model",
input="Hello, how are you?",
encoding_format="float",
)
self.list_req = EmbeddingRequest(
model="test-model",
input=["Hello, how are you?", "I am fine, thank you!"],
encoding_format="float",
)
self.multimodal_req = EmbeddingRequest(
model="test-model",
input=[
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
MultimodalEmbeddingInput(text="World", image=None),
],
encoding_format="float",
)
self.token_ids_req = EmbeddingRequest(
model="test-model",
input=[1, 2, 3, 4, 5],
encoding_format="float",
)
@pytest.fixture
def mock_request():
"""Create a mock FastAPI request."""
request = Mock(spec=Request)
request.headers = {}
return request
@pytest.fixture
def basic_embedding_request():
"""Create a basic embedding request."""
return EmbeddingRequest(
model="test-model",
input="Hello, how are you?",
encoding_format="float",
)
@pytest.fixture
def list_embedding_request():
"""Create an embedding request with list input."""
return EmbeddingRequest(
model="test-model",
input=["Hello, how are you?", "I am fine, thank you!"],
encoding_format="float",
)
@pytest.fixture
def multimodal_embedding_request():
"""Create a multimodal embedding request."""
return EmbeddingRequest(
model="test-model",
input=[
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
MultimodalEmbeddingInput(text="World", image=None),
],
encoding_format="float",
)
@pytest.fixture
def token_ids_embedding_request():
"""Create an embedding request with token IDs."""
return EmbeddingRequest(
model="test-model",
input=[1, 2, 3, 4, 5],
encoding_format="float",
)
class TestOpenAIServingEmbeddingConversion:
"""Test request conversion methods."""
def test_convert_single_string_request(
self, serving_embedding, basic_embedding_request
):
def test_convert_single_string_request(self):
"""Test converting single string request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[basic_embedding_request], ["test-id"]
self.serving_embedding._convert_to_internal_request(
[self.basic_req], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
assert adapted_request.text == "Hello, how are you?"
assert adapted_request.rid == "test-id"
assert processed_request == basic_embedding_request
self.assertIsInstance(adapted_request, EmbeddingReqInput)
self.assertEqual(adapted_request.text, "Hello, how are you?")
self.assertEqual(adapted_request.rid, "test-id")
self.assertEqual(processed_request, self.basic_req)
def test_convert_list_string_request(
self, serving_embedding, list_embedding_request
):
def test_convert_list_string_request(self):
"""Test converting list of strings request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[list_embedding_request], ["test-id"]
self.serving_embedding._convert_to_internal_request(
[self.list_req], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"]
assert adapted_request.rid == "test-id"
assert processed_request == list_embedding_request
self.assertIsInstance(adapted_request, EmbeddingReqInput)
self.assertEqual(
adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
)
self.assertEqual(adapted_request.rid, "test-id")
self.assertEqual(processed_request, self.list_req)
def test_convert_token_ids_request(
self, serving_embedding, token_ids_embedding_request
):
def test_convert_token_ids_request(self):
"""Test converting token IDs request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[token_ids_embedding_request], ["test-id"]
self.serving_embedding._convert_to_internal_request(
[self.token_ids_req], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
assert adapted_request.input_ids == [1, 2, 3, 4, 5]
assert adapted_request.rid == "test-id"
assert processed_request == token_ids_embedding_request
self.assertIsInstance(adapted_request, EmbeddingReqInput)
self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
self.assertEqual(adapted_request.rid, "test-id")
self.assertEqual(processed_request, self.token_ids_req)
def test_convert_multimodal_request(
self, serving_embedding, multimodal_embedding_request
):
def test_convert_multimodal_request(self):
"""Test converting multimodal request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[multimodal_embedding_request], ["test-id"]
self.serving_embedding._convert_to_internal_request(
[self.multimodal_req], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
self.assertIsInstance(adapted_request, EmbeddingReqInput)
# Should extract text and images separately
assert len(adapted_request.text) == 2
assert "Hello" in adapted_request.text
assert "World" in adapted_request.text
assert adapted_request.image_data[0] == "base64_image_data"
assert adapted_request.image_data[1] is None
assert adapted_request.rid == "test-id"
self.assertEqual(len(adapted_request.text), 2)
self.assertIn("Hello", adapted_request.text)
self.assertIn("World", adapted_request.text)
self.assertEqual(adapted_request.image_data[0], "base64_image_data")
self.assertIsNone(adapted_request.image_data[1])
self.assertEqual(adapted_request.rid, "test-id")
class TestEmbeddingResponseBuilding:
"""Test response building methods."""
def test_build_single_embedding_response(self, serving_embedding):
def test_build_single_embedding_response(self):
"""Test building response for single embedding."""
ret_data = [
{
@@ -201,19 +158,21 @@ class TestEmbeddingResponseBuilding:
}
]
response = serving_embedding._build_embedding_response(ret_data, "test-model")
response = self.serving_embedding._build_embedding_response(
ret_data, "test-model"
)
assert isinstance(response, EmbeddingResponse)
assert response.model == "test-model"
assert len(response.data) == 1
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
assert response.data[0].index == 0
assert response.data[0].object == "embedding"
assert response.usage.prompt_tokens == 5
assert response.usage.total_tokens == 5
assert response.usage.completion_tokens == 0
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.data[0].object, "embedding")
self.assertEqual(response.usage.prompt_tokens, 5)
self.assertEqual(response.usage.total_tokens, 5)
self.assertEqual(response.usage.completion_tokens, 0)
def test_build_multiple_embedding_response(self, serving_embedding):
def test_build_multiple_embedding_response(self):
"""Test building response for multiple embeddings."""
ret_data = [
{
@@ -226,25 +185,20 @@ class TestEmbeddingResponseBuilding:
},
]
response = serving_embedding._build_embedding_response(ret_data, "test-model")
response = self.serving_embedding._build_embedding_response(
ret_data, "test-model"
)
assert isinstance(response, EmbeddingResponse)
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
assert response.data[0].index == 0
assert response.data[1].embedding == [0.4, 0.5, 0.6]
assert response.data[1].index == 1
assert response.usage.prompt_tokens == 7 # 3 + 4
assert response.usage.total_tokens == 7
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(len(response.data), 2)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
self.assertEqual(response.data[1].index, 1)
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
self.assertEqual(response.usage.total_tokens, 7)
@pytest.mark.asyncio
class TestOpenAIServingEmbeddingAsyncMethods:
"""Test async methods of OpenAIServingEmbedding."""
async def test_handle_request_success(
self, serving_embedding, basic_embedding_request, mock_request
):
async def test_handle_request_success(self):
"""Test successful embedding request handling."""
# Mock the generate_request to return expected data
@@ -254,32 +208,30 @@ class TestOpenAIServingEmbeddingAsyncMethods:
"meta_info": {"prompt_tokens": 5},
}
serving_embedding.tokenizer_manager.generate_request = Mock(
self.serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate()
)
response = await serving_embedding.handle_request(
basic_embedding_request, mock_request
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
assert isinstance(response, EmbeddingResponse)
assert len(response.data) == 1
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
async def test_handle_request_validation_error(
self, serving_embedding, mock_request
):
async def test_handle_request_validation_error(self):
"""Test handling request with validation error."""
invalid_request = EmbeddingRequest(model="test-model", input="")
response = await serving_embedding.handle_request(invalid_request, mock_request)
response = await self.serving_embedding.handle_request(
invalid_request, self.request
)
assert isinstance(response, ORJSONResponse)
assert response.status_code == 400
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 400)
async def test_handle_request_generation_error(
self, serving_embedding, basic_embedding_request, mock_request
):
async def test_handle_request_generation_error(self):
"""Test handling request with generation error."""
# Mock generate_request to raise an error
@@ -287,30 +239,32 @@ class TestOpenAIServingEmbeddingAsyncMethods:
raise ValueError("Generation failed")
yield # This won't be reached but needed for async generator
serving_embedding.tokenizer_manager.generate_request = Mock(
self.serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate_error()
)
response = await serving_embedding.handle_request(
basic_embedding_request, mock_request
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
assert isinstance(response, ORJSONResponse)
assert response.status_code == 400
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 400)
async def test_handle_request_internal_error(
self, serving_embedding, basic_embedding_request, mock_request
):
async def test_handle_request_internal_error(self):
"""Test handling request with internal server error."""
# Mock _convert_to_internal_request to raise an exception
with patch.object(
serving_embedding,
self.serving_embedding,
"_convert_to_internal_request",
side_effect=Exception("Internal error"),
):
response = await serving_embedding.handle_request(
basic_embedding_request, mock_request
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
assert isinstance(response, ORJSONResponse)
assert response.status_code == 500
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 500)
if __name__ == "__main__":
unittest.main(verbosity=2)