Add more refactored openai test & in CI (#7284)
This commit is contained in:
@@ -36,7 +36,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
|
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
FakeBootstrapHost,
|
FAKE_BOOTSTRAP_HOST,
|
||||||
register_disaggregation_server,
|
register_disaggregation_server,
|
||||||
)
|
)
|
||||||
from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
|
from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
|
||||||
@@ -265,7 +265,7 @@ def _wait_and_warmup(
|
|||||||
"max_new_tokens": 8,
|
"max_new_tokens": 8,
|
||||||
"ignore_eos": True,
|
"ignore_eos": True,
|
||||||
},
|
},
|
||||||
"bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
|
"bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
|
||||||
# This is a hack to ensure fake transfer is enabled during prefill warmup
|
# This is a hack to ensure fake transfer is enabled during prefill warmup
|
||||||
# ensure each dp rank has a unique bootstrap_room during prefill warmup
|
# ensure each dp rank has a unique bootstrap_room during prefill warmup
|
||||||
"bootstrap_room": [
|
"bootstrap_room": [
|
||||||
|
|||||||
@@ -12,9 +12,10 @@ import pytest
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree # reuse SGLang helper
|
from sglang.srt.utils import kill_process_tree # reuse SGLang helper
|
||||||
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
|
SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
|
||||||
DEFAULT_MODEL = "dummy-model"
|
DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
|
STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
|
||||||
|
|
||||||
|
|
||||||
@@ -39,7 +40,7 @@ def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> No
|
|||||||
|
|
||||||
|
|
||||||
def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
|
def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
|
||||||
"""Spawn the draft OpenAI-compatible server and wait until it’s ready."""
|
"""Spawn the draft OpenAI-compatible server and wait until it's ready."""
|
||||||
port = _pick_free_port()
|
port = _pick_free_port()
|
||||||
cmd = [
|
cmd = [
|
||||||
sys.executable,
|
sys.executable,
|
||||||
@@ -79,7 +80,7 @@ def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def openai_server() -> Generator[str, None, None]:
|
def openai_server() -> Generator[str, None, None]:
|
||||||
"""PyTest fixture that provides the server’s base URL and cleans up."""
|
"""PyTest fixture that provides the server's base URL and cleans up."""
|
||||||
proc, base, log_file = launch_openai_server()
|
proc, base, log_file = launch_openai_server()
|
||||||
yield base
|
yield base
|
||||||
kill_process_tree(proc.pid)
|
kill_process_tree(proc.pid)
|
||||||
|
|||||||
@@ -15,9 +15,9 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
import unittest
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
@@ -64,18 +64,18 @@ from sglang.srt.entrypoints.openai.protocol import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestModelCard:
|
class TestModelCard(unittest.TestCase):
|
||||||
"""Test ModelCard protocol model"""
|
"""Test ModelCard protocol model"""
|
||||||
|
|
||||||
def test_basic_model_card_creation(self):
|
def test_basic_model_card_creation(self):
|
||||||
"""Test basic model card creation with required fields"""
|
"""Test basic model card creation with required fields"""
|
||||||
card = ModelCard(id="test-model")
|
card = ModelCard(id="test-model")
|
||||||
assert card.id == "test-model"
|
self.assertEqual(card.id, "test-model")
|
||||||
assert card.object == "model"
|
self.assertEqual(card.object, "model")
|
||||||
assert card.owned_by == "sglang"
|
self.assertEqual(card.owned_by, "sglang")
|
||||||
assert isinstance(card.created, int)
|
self.assertIsInstance(card.created, int)
|
||||||
assert card.root is None
|
self.assertIsNone(card.root)
|
||||||
assert card.max_model_len is None
|
self.assertIsNone(card.max_model_len)
|
||||||
|
|
||||||
def test_model_card_with_optional_fields(self):
|
def test_model_card_with_optional_fields(self):
|
||||||
"""Test model card with optional fields"""
|
"""Test model card with optional fields"""
|
||||||
@@ -85,28 +85,28 @@ class TestModelCard:
|
|||||||
max_model_len=2048,
|
max_model_len=2048,
|
||||||
created=1234567890,
|
created=1234567890,
|
||||||
)
|
)
|
||||||
assert card.id == "test-model"
|
self.assertEqual(card.id, "test-model")
|
||||||
assert card.root == "/path/to/model"
|
self.assertEqual(card.root, "/path/to/model")
|
||||||
assert card.max_model_len == 2048
|
self.assertEqual(card.max_model_len, 2048)
|
||||||
assert card.created == 1234567890
|
self.assertEqual(card.created, 1234567890)
|
||||||
|
|
||||||
def test_model_card_serialization(self):
|
def test_model_card_serialization(self):
|
||||||
"""Test model card JSON serialization"""
|
"""Test model card JSON serialization"""
|
||||||
card = ModelCard(id="test-model", max_model_len=4096)
|
card = ModelCard(id="test-model", max_model_len=4096)
|
||||||
data = card.model_dump()
|
data = card.model_dump()
|
||||||
assert data["id"] == "test-model"
|
self.assertEqual(data["id"], "test-model")
|
||||||
assert data["object"] == "model"
|
self.assertEqual(data["object"], "model")
|
||||||
assert data["max_model_len"] == 4096
|
self.assertEqual(data["max_model_len"], 4096)
|
||||||
|
|
||||||
|
|
||||||
class TestModelList:
|
class TestModelList(unittest.TestCase):
|
||||||
"""Test ModelList protocol model"""
|
"""Test ModelList protocol model"""
|
||||||
|
|
||||||
def test_empty_model_list(self):
|
def test_empty_model_list(self):
|
||||||
"""Test empty model list creation"""
|
"""Test empty model list creation"""
|
||||||
model_list = ModelList()
|
model_list = ModelList()
|
||||||
assert model_list.object == "list"
|
self.assertEqual(model_list.object, "list")
|
||||||
assert len(model_list.data) == 0
|
self.assertEqual(len(model_list.data), 0)
|
||||||
|
|
||||||
def test_model_list_with_cards(self):
|
def test_model_list_with_cards(self):
|
||||||
"""Test model list with model cards"""
|
"""Test model list with model cards"""
|
||||||
@@ -115,12 +115,12 @@ class TestModelList:
|
|||||||
ModelCard(id="model-2", max_model_len=2048),
|
ModelCard(id="model-2", max_model_len=2048),
|
||||||
]
|
]
|
||||||
model_list = ModelList(data=cards)
|
model_list = ModelList(data=cards)
|
||||||
assert len(model_list.data) == 2
|
self.assertEqual(len(model_list.data), 2)
|
||||||
assert model_list.data[0].id == "model-1"
|
self.assertEqual(model_list.data[0].id, "model-1")
|
||||||
assert model_list.data[1].id == "model-2"
|
self.assertEqual(model_list.data[1].id, "model-2")
|
||||||
|
|
||||||
|
|
||||||
class TestErrorResponse:
|
class TestErrorResponse(unittest.TestCase):
|
||||||
"""Test ErrorResponse protocol model"""
|
"""Test ErrorResponse protocol model"""
|
||||||
|
|
||||||
def test_basic_error_response(self):
|
def test_basic_error_response(self):
|
||||||
@@ -128,11 +128,11 @@ class TestErrorResponse:
|
|||||||
error = ErrorResponse(
|
error = ErrorResponse(
|
||||||
message="Invalid request", type="BadRequestError", code=400
|
message="Invalid request", type="BadRequestError", code=400
|
||||||
)
|
)
|
||||||
assert error.object == "error"
|
self.assertEqual(error.object, "error")
|
||||||
assert error.message == "Invalid request"
|
self.assertEqual(error.message, "Invalid request")
|
||||||
assert error.type == "BadRequestError"
|
self.assertEqual(error.type, "BadRequestError")
|
||||||
assert error.code == 400
|
self.assertEqual(error.code, 400)
|
||||||
assert error.param is None
|
self.assertIsNone(error.param)
|
||||||
|
|
||||||
def test_error_response_with_param(self):
|
def test_error_response_with_param(self):
|
||||||
"""Test error response with parameter"""
|
"""Test error response with parameter"""
|
||||||
@@ -142,19 +142,19 @@ class TestErrorResponse:
|
|||||||
code=422,
|
code=422,
|
||||||
param="temperature",
|
param="temperature",
|
||||||
)
|
)
|
||||||
assert error.param == "temperature"
|
self.assertEqual(error.param, "temperature")
|
||||||
|
|
||||||
|
|
||||||
class TestUsageInfo:
|
class TestUsageInfo(unittest.TestCase):
|
||||||
"""Test UsageInfo protocol model"""
|
"""Test UsageInfo protocol model"""
|
||||||
|
|
||||||
def test_basic_usage_info(self):
|
def test_basic_usage_info(self):
|
||||||
"""Test basic usage info creation"""
|
"""Test basic usage info creation"""
|
||||||
usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||||
assert usage.prompt_tokens == 10
|
self.assertEqual(usage.prompt_tokens, 10)
|
||||||
assert usage.completion_tokens == 20
|
self.assertEqual(usage.completion_tokens, 20)
|
||||||
assert usage.total_tokens == 30
|
self.assertEqual(usage.total_tokens, 30)
|
||||||
assert usage.prompt_tokens_details is None
|
self.assertIsNone(usage.prompt_tokens_details)
|
||||||
|
|
||||||
def test_usage_info_with_cache_details(self):
|
def test_usage_info_with_cache_details(self):
|
||||||
"""Test usage info with cache details"""
|
"""Test usage info with cache details"""
|
||||||
@@ -164,22 +164,22 @@ class TestUsageInfo:
|
|||||||
total_tokens=30,
|
total_tokens=30,
|
||||||
prompt_tokens_details={"cached_tokens": 5},
|
prompt_tokens_details={"cached_tokens": 5},
|
||||||
)
|
)
|
||||||
assert usage.prompt_tokens_details == {"cached_tokens": 5}
|
self.assertEqual(usage.prompt_tokens_details, {"cached_tokens": 5})
|
||||||
|
|
||||||
|
|
||||||
class TestCompletionRequest:
|
class TestCompletionRequest(unittest.TestCase):
|
||||||
"""Test CompletionRequest protocol model"""
|
"""Test CompletionRequest protocol model"""
|
||||||
|
|
||||||
def test_basic_completion_request(self):
|
def test_basic_completion_request(self):
|
||||||
"""Test basic completion request"""
|
"""Test basic completion request"""
|
||||||
request = CompletionRequest(model="test-model", prompt="Hello world")
|
request = CompletionRequest(model="test-model", prompt="Hello world")
|
||||||
assert request.model == "test-model"
|
self.assertEqual(request.model, "test-model")
|
||||||
assert request.prompt == "Hello world"
|
self.assertEqual(request.prompt, "Hello world")
|
||||||
assert request.max_tokens == 16 # default
|
self.assertEqual(request.max_tokens, 16) # default
|
||||||
assert request.temperature == 1.0 # default
|
self.assertEqual(request.temperature, 1.0) # default
|
||||||
assert request.n == 1 # default
|
self.assertEqual(request.n, 1) # default
|
||||||
assert not request.stream # default
|
self.assertFalse(request.stream) # default
|
||||||
assert not request.echo # default
|
self.assertFalse(request.echo) # default
|
||||||
|
|
||||||
def test_completion_request_with_options(self):
|
def test_completion_request_with_options(self):
|
||||||
"""Test completion request with various options"""
|
"""Test completion request with various options"""
|
||||||
@@ -195,15 +195,15 @@ class TestCompletionRequest:
|
|||||||
stop=[".", "!"],
|
stop=[".", "!"],
|
||||||
logprobs=5,
|
logprobs=5,
|
||||||
)
|
)
|
||||||
assert request.prompt == ["Hello", "world"]
|
self.assertEqual(request.prompt, ["Hello", "world"])
|
||||||
assert request.max_tokens == 100
|
self.assertEqual(request.max_tokens, 100)
|
||||||
assert request.temperature == 0.7
|
self.assertEqual(request.temperature, 0.7)
|
||||||
assert request.top_p == 0.9
|
self.assertEqual(request.top_p, 0.9)
|
||||||
assert request.n == 2
|
self.assertEqual(request.n, 2)
|
||||||
assert request.stream
|
self.assertTrue(request.stream)
|
||||||
assert request.echo
|
self.assertTrue(request.echo)
|
||||||
assert request.stop == [".", "!"]
|
self.assertEqual(request.stop, [".", "!"])
|
||||||
assert request.logprobs == 5
|
self.assertEqual(request.logprobs, 5)
|
||||||
|
|
||||||
def test_completion_request_sglang_extensions(self):
|
def test_completion_request_sglang_extensions(self):
|
||||||
"""Test completion request with SGLang-specific extensions"""
|
"""Test completion request with SGLang-specific extensions"""
|
||||||
@@ -217,23 +217,23 @@ class TestCompletionRequest:
|
|||||||
json_schema='{"type": "object"}',
|
json_schema='{"type": "object"}',
|
||||||
lora_path="/path/to/lora",
|
lora_path="/path/to/lora",
|
||||||
)
|
)
|
||||||
assert request.top_k == 50
|
self.assertEqual(request.top_k, 50)
|
||||||
assert request.min_p == 0.1
|
self.assertEqual(request.min_p, 0.1)
|
||||||
assert request.repetition_penalty == 1.1
|
self.assertEqual(request.repetition_penalty, 1.1)
|
||||||
assert request.regex == r"\d+"
|
self.assertEqual(request.regex, r"\d+")
|
||||||
assert request.json_schema == '{"type": "object"}'
|
self.assertEqual(request.json_schema, '{"type": "object"}')
|
||||||
assert request.lora_path == "/path/to/lora"
|
self.assertEqual(request.lora_path, "/path/to/lora")
|
||||||
|
|
||||||
def test_completion_request_validation_errors(self):
|
def test_completion_request_validation_errors(self):
|
||||||
"""Test completion request validation errors"""
|
"""Test completion request validation errors"""
|
||||||
with pytest.raises(ValidationError):
|
with self.assertRaises(ValidationError):
|
||||||
CompletionRequest() # missing required fields
|
CompletionRequest() # missing required fields
|
||||||
|
|
||||||
with pytest.raises(ValidationError):
|
with self.assertRaises(ValidationError):
|
||||||
CompletionRequest(model="test-model") # missing prompt
|
CompletionRequest(model="test-model") # missing prompt
|
||||||
|
|
||||||
|
|
||||||
class TestCompletionResponse:
|
class TestCompletionResponse(unittest.TestCase):
|
||||||
"""Test CompletionResponse protocol model"""
|
"""Test CompletionResponse protocol model"""
|
||||||
|
|
||||||
def test_basic_completion_response(self):
|
def test_basic_completion_response(self):
|
||||||
@@ -245,28 +245,28 @@ class TestCompletionResponse:
|
|||||||
response = CompletionResponse(
|
response = CompletionResponse(
|
||||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
id="test-id", model="test-model", choices=[choice], usage=usage
|
||||||
)
|
)
|
||||||
assert response.id == "test-id"
|
self.assertEqual(response.id, "test-id")
|
||||||
assert response.object == "text_completion"
|
self.assertEqual(response.object, "text_completion")
|
||||||
assert response.model == "test-model"
|
self.assertEqual(response.model, "test-model")
|
||||||
assert len(response.choices) == 1
|
self.assertEqual(len(response.choices), 1)
|
||||||
assert response.choices[0].text == "Hello world!"
|
self.assertEqual(response.choices[0].text, "Hello world!")
|
||||||
assert response.usage.total_tokens == 5
|
self.assertEqual(response.usage.total_tokens, 5)
|
||||||
|
|
||||||
|
|
||||||
class TestChatCompletionRequest:
|
class TestChatCompletionRequest(unittest.TestCase):
|
||||||
"""Test ChatCompletionRequest protocol model"""
|
"""Test ChatCompletionRequest protocol model"""
|
||||||
|
|
||||||
def test_basic_chat_completion_request(self):
|
def test_basic_chat_completion_request(self):
|
||||||
"""Test basic chat completion request"""
|
"""Test basic chat completion request"""
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
request = ChatCompletionRequest(model="test-model", messages=messages)
|
request = ChatCompletionRequest(model="test-model", messages=messages)
|
||||||
assert request.model == "test-model"
|
self.assertEqual(request.model, "test-model")
|
||||||
assert len(request.messages) == 1
|
self.assertEqual(len(request.messages), 1)
|
||||||
assert request.messages[0].role == "user"
|
self.assertEqual(request.messages[0].role, "user")
|
||||||
assert request.messages[0].content == "Hello"
|
self.assertEqual(request.messages[0].content, "Hello")
|
||||||
assert request.temperature == 0.7 # default
|
self.assertEqual(request.temperature, 0.7) # default
|
||||||
assert not request.stream # default
|
self.assertFalse(request.stream) # default
|
||||||
assert request.tool_choice == "none" # default when no tools
|
self.assertEqual(request.tool_choice, "none") # default when no tools
|
||||||
|
|
||||||
def test_chat_completion_with_multimodal_content(self):
|
def test_chat_completion_with_multimodal_content(self):
|
||||||
"""Test chat completion with multimodal content"""
|
"""Test chat completion with multimodal content"""
|
||||||
@@ -283,9 +283,9 @@ class TestChatCompletionRequest:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
request = ChatCompletionRequest(model="test-model", messages=messages)
|
request = ChatCompletionRequest(model="test-model", messages=messages)
|
||||||
assert len(request.messages[0].content) == 2
|
self.assertEqual(len(request.messages[0].content), 2)
|
||||||
assert request.messages[0].content[0].type == "text"
|
self.assertEqual(request.messages[0].content[0].type, "text")
|
||||||
assert request.messages[0].content[1].type == "image_url"
|
self.assertEqual(request.messages[0].content[1].type, "image_url")
|
||||||
|
|
||||||
def test_chat_completion_with_tools(self):
|
def test_chat_completion_with_tools(self):
|
||||||
"""Test chat completion with tools"""
|
"""Test chat completion with tools"""
|
||||||
@@ -306,9 +306,9 @@ class TestChatCompletionRequest:
|
|||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
model="test-model", messages=messages, tools=tools
|
model="test-model", messages=messages, tools=tools
|
||||||
)
|
)
|
||||||
assert len(request.tools) == 1
|
self.assertEqual(len(request.tools), 1)
|
||||||
assert request.tools[0].function.name == "get_weather"
|
self.assertEqual(request.tools[0].function.name, "get_weather")
|
||||||
assert request.tool_choice == "auto" # default when tools present
|
self.assertEqual(request.tool_choice, "auto") # default when tools present
|
||||||
|
|
||||||
def test_chat_completion_tool_choice_validation(self):
|
def test_chat_completion_tool_choice_validation(self):
|
||||||
"""Test tool choice validation logic"""
|
"""Test tool choice validation logic"""
|
||||||
@@ -316,7 +316,7 @@ class TestChatCompletionRequest:
|
|||||||
|
|
||||||
# No tools, tool_choice should default to "none"
|
# No tools, tool_choice should default to "none"
|
||||||
request1 = ChatCompletionRequest(model="test-model", messages=messages)
|
request1 = ChatCompletionRequest(model="test-model", messages=messages)
|
||||||
assert request1.tool_choice == "none"
|
self.assertEqual(request1.tool_choice, "none")
|
||||||
|
|
||||||
# With tools, tool_choice should default to "auto"
|
# With tools, tool_choice should default to "auto"
|
||||||
tools = [
|
tools = [
|
||||||
@@ -328,7 +328,7 @@ class TestChatCompletionRequest:
|
|||||||
request2 = ChatCompletionRequest(
|
request2 = ChatCompletionRequest(
|
||||||
model="test-model", messages=messages, tools=tools
|
model="test-model", messages=messages, tools=tools
|
||||||
)
|
)
|
||||||
assert request2.tool_choice == "auto"
|
self.assertEqual(request2.tool_choice, "auto")
|
||||||
|
|
||||||
def test_chat_completion_sglang_extensions(self):
|
def test_chat_completion_sglang_extensions(self):
|
||||||
"""Test chat completion with SGLang extensions"""
|
"""Test chat completion with SGLang extensions"""
|
||||||
@@ -342,14 +342,14 @@ class TestChatCompletionRequest:
|
|||||||
stream_reasoning=False,
|
stream_reasoning=False,
|
||||||
chat_template_kwargs={"custom_param": "value"},
|
chat_template_kwargs={"custom_param": "value"},
|
||||||
)
|
)
|
||||||
assert request.top_k == 40
|
self.assertEqual(request.top_k, 40)
|
||||||
assert request.min_p == 0.05
|
self.assertEqual(request.min_p, 0.05)
|
||||||
assert not request.separate_reasoning
|
self.assertFalse(request.separate_reasoning)
|
||||||
assert not request.stream_reasoning
|
self.assertFalse(request.stream_reasoning)
|
||||||
assert request.chat_template_kwargs == {"custom_param": "value"}
|
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
|
||||||
|
|
||||||
|
|
||||||
class TestChatCompletionResponse:
|
class TestChatCompletionResponse(unittest.TestCase):
|
||||||
"""Test ChatCompletionResponse protocol model"""
|
"""Test ChatCompletionResponse protocol model"""
|
||||||
|
|
||||||
def test_basic_chat_completion_response(self):
|
def test_basic_chat_completion_response(self):
|
||||||
@@ -362,11 +362,11 @@ class TestChatCompletionResponse:
|
|||||||
response = ChatCompletionResponse(
|
response = ChatCompletionResponse(
|
||||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
id="test-id", model="test-model", choices=[choice], usage=usage
|
||||||
)
|
)
|
||||||
assert response.id == "test-id"
|
self.assertEqual(response.id, "test-id")
|
||||||
assert response.object == "chat.completion"
|
self.assertEqual(response.object, "chat.completion")
|
||||||
assert response.model == "test-model"
|
self.assertEqual(response.model, "test-model")
|
||||||
assert len(response.choices) == 1
|
self.assertEqual(len(response.choices), 1)
|
||||||
assert response.choices[0].message.content == "Hello there!"
|
self.assertEqual(response.choices[0].message.content, "Hello there!")
|
||||||
|
|
||||||
def test_chat_completion_response_with_tool_calls(self):
|
def test_chat_completion_response_with_tool_calls(self):
|
||||||
"""Test chat completion response with tool calls"""
|
"""Test chat completion response with tool calls"""
|
||||||
@@ -384,28 +384,30 @@ class TestChatCompletionResponse:
|
|||||||
response = ChatCompletionResponse(
|
response = ChatCompletionResponse(
|
||||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
id="test-id", model="test-model", choices=[choice], usage=usage
|
||||||
)
|
)
|
||||||
assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
|
self.assertEqual(
|
||||||
assert response.choices[0].finish_reason == "tool_calls"
|
response.choices[0].message.tool_calls[0].function.name, "get_weather"
|
||||||
|
)
|
||||||
|
self.assertEqual(response.choices[0].finish_reason, "tool_calls")
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingRequest:
|
class TestEmbeddingRequest(unittest.TestCase):
|
||||||
"""Test EmbeddingRequest protocol model"""
|
"""Test EmbeddingRequest protocol model"""
|
||||||
|
|
||||||
def test_basic_embedding_request(self):
|
def test_basic_embedding_request(self):
|
||||||
"""Test basic embedding request"""
|
"""Test basic embedding request"""
|
||||||
request = EmbeddingRequest(model="test-model", input="Hello world")
|
request = EmbeddingRequest(model="test-model", input="Hello world")
|
||||||
assert request.model == "test-model"
|
self.assertEqual(request.model, "test-model")
|
||||||
assert request.input == "Hello world"
|
self.assertEqual(request.input, "Hello world")
|
||||||
assert request.encoding_format == "float" # default
|
self.assertEqual(request.encoding_format, "float") # default
|
||||||
assert request.dimensions is None # default
|
self.assertIsNone(request.dimensions) # default
|
||||||
|
|
||||||
def test_embedding_request_with_list_input(self):
|
def test_embedding_request_with_list_input(self):
|
||||||
"""Test embedding request with list input"""
|
"""Test embedding request with list input"""
|
||||||
request = EmbeddingRequest(
|
request = EmbeddingRequest(
|
||||||
model="test-model", input=["Hello", "world"], dimensions=512
|
model="test-model", input=["Hello", "world"], dimensions=512
|
||||||
)
|
)
|
||||||
assert request.input == ["Hello", "world"]
|
self.assertEqual(request.input, ["Hello", "world"])
|
||||||
assert request.dimensions == 512
|
self.assertEqual(request.dimensions, 512)
|
||||||
|
|
||||||
def test_multimodal_embedding_request(self):
|
def test_multimodal_embedding_request(self):
|
||||||
"""Test multimodal embedding request"""
|
"""Test multimodal embedding request"""
|
||||||
@@ -414,14 +416,14 @@ class TestEmbeddingRequest:
|
|||||||
MultimodalEmbeddingInput(text="World", image=None),
|
MultimodalEmbeddingInput(text="World", image=None),
|
||||||
]
|
]
|
||||||
request = EmbeddingRequest(model="test-model", input=multimodal_input)
|
request = EmbeddingRequest(model="test-model", input=multimodal_input)
|
||||||
assert len(request.input) == 2
|
self.assertEqual(len(request.input), 2)
|
||||||
assert request.input[0].text == "Hello"
|
self.assertEqual(request.input[0].text, "Hello")
|
||||||
assert request.input[0].image == "base64_image_data"
|
self.assertEqual(request.input[0].image, "base64_image_data")
|
||||||
assert request.input[1].text == "World"
|
self.assertEqual(request.input[1].text, "World")
|
||||||
assert request.input[1].image is None
|
self.assertIsNone(request.input[1].image)
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingResponse:
|
class TestEmbeddingResponse(unittest.TestCase):
|
||||||
"""Test EmbeddingResponse protocol model"""
|
"""Test EmbeddingResponse protocol model"""
|
||||||
|
|
||||||
def test_basic_embedding_response(self):
|
def test_basic_embedding_response(self):
|
||||||
@@ -431,14 +433,14 @@ class TestEmbeddingResponse:
|
|||||||
response = EmbeddingResponse(
|
response = EmbeddingResponse(
|
||||||
data=[embedding_obj], model="test-model", usage=usage
|
data=[embedding_obj], model="test-model", usage=usage
|
||||||
)
|
)
|
||||||
assert response.object == "list"
|
self.assertEqual(response.object, "list")
|
||||||
assert len(response.data) == 1
|
self.assertEqual(len(response.data), 1)
|
||||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
|
||||||
assert response.data[0].index == 0
|
self.assertEqual(response.data[0].index, 0)
|
||||||
assert response.usage.prompt_tokens == 3
|
self.assertEqual(response.usage.prompt_tokens, 3)
|
||||||
|
|
||||||
|
|
||||||
class TestScoringRequest:
|
class TestScoringRequest(unittest.TestCase):
|
||||||
"""Test ScoringRequest protocol model"""
|
"""Test ScoringRequest protocol model"""
|
||||||
|
|
||||||
def test_basic_scoring_request(self):
|
def test_basic_scoring_request(self):
|
||||||
@@ -446,11 +448,11 @@ class TestScoringRequest:
|
|||||||
request = ScoringRequest(
|
request = ScoringRequest(
|
||||||
model="test-model", query="Hello", items=["World", "Earth"]
|
model="test-model", query="Hello", items=["World", "Earth"]
|
||||||
)
|
)
|
||||||
assert request.model == "test-model"
|
self.assertEqual(request.model, "test-model")
|
||||||
assert request.query == "Hello"
|
self.assertEqual(request.query, "Hello")
|
||||||
assert request.items == ["World", "Earth"]
|
self.assertEqual(request.items, ["World", "Earth"])
|
||||||
assert not request.apply_softmax # default
|
self.assertFalse(request.apply_softmax) # default
|
||||||
assert not request.item_first # default
|
self.assertFalse(request.item_first) # default
|
||||||
|
|
||||||
def test_scoring_request_with_token_ids(self):
|
def test_scoring_request_with_token_ids(self):
|
||||||
"""Test scoring request with token IDs"""
|
"""Test scoring request with token IDs"""
|
||||||
@@ -462,34 +464,34 @@ class TestScoringRequest:
|
|||||||
apply_softmax=True,
|
apply_softmax=True,
|
||||||
item_first=True,
|
item_first=True,
|
||||||
)
|
)
|
||||||
assert request.query == [1, 2, 3]
|
self.assertEqual(request.query, [1, 2, 3])
|
||||||
assert request.items == [[4, 5], [6, 7]]
|
self.assertEqual(request.items, [[4, 5], [6, 7]])
|
||||||
assert request.label_token_ids == [8, 9]
|
self.assertEqual(request.label_token_ids, [8, 9])
|
||||||
assert request.apply_softmax
|
self.assertTrue(request.apply_softmax)
|
||||||
assert request.item_first
|
self.assertTrue(request.item_first)
|
||||||
|
|
||||||
|
|
||||||
class TestScoringResponse:
|
class TestScoringResponse(unittest.TestCase):
|
||||||
"""Test ScoringResponse protocol model"""
|
"""Test ScoringResponse protocol model"""
|
||||||
|
|
||||||
def test_basic_scoring_response(self):
|
def test_basic_scoring_response(self):
|
||||||
"""Test basic scoring response"""
|
"""Test basic scoring response"""
|
||||||
response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
|
response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
|
||||||
assert response.object == "scoring"
|
self.assertEqual(response.object, "scoring")
|
||||||
assert response.scores == [[0.1, 0.9], [0.3, 0.7]]
|
self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]])
|
||||||
assert response.model == "test-model"
|
self.assertEqual(response.model, "test-model")
|
||||||
assert response.usage is None # default
|
self.assertIsNone(response.usage) # default
|
||||||
|
|
||||||
|
|
||||||
class TestFileOperations:
|
class TestFileOperations(unittest.TestCase):
|
||||||
"""Test file operation protocol models"""
|
"""Test file operation protocol models"""
|
||||||
|
|
||||||
def test_file_request(self):
|
def test_file_request(self):
|
||||||
"""Test file request model"""
|
"""Test file request model"""
|
||||||
file_data = b"test file content"
|
file_data = b"test file content"
|
||||||
request = FileRequest(file=file_data, purpose="batch")
|
request = FileRequest(file=file_data, purpose="batch")
|
||||||
assert request.file == file_data
|
self.assertEqual(request.file, file_data)
|
||||||
assert request.purpose == "batch"
|
self.assertEqual(request.purpose, "batch")
|
||||||
|
|
||||||
def test_file_response(self):
|
def test_file_response(self):
|
||||||
"""Test file response model"""
|
"""Test file response model"""
|
||||||
@@ -500,20 +502,20 @@ class TestFileOperations:
|
|||||||
filename="test.jsonl",
|
filename="test.jsonl",
|
||||||
purpose="batch",
|
purpose="batch",
|
||||||
)
|
)
|
||||||
assert response.id == "file-123"
|
self.assertEqual(response.id, "file-123")
|
||||||
assert response.object == "file"
|
self.assertEqual(response.object, "file")
|
||||||
assert response.bytes == 1024
|
self.assertEqual(response.bytes, 1024)
|
||||||
assert response.filename == "test.jsonl"
|
self.assertEqual(response.filename, "test.jsonl")
|
||||||
|
|
||||||
def test_file_delete_response(self):
|
def test_file_delete_response(self):
|
||||||
"""Test file delete response model"""
|
"""Test file delete response model"""
|
||||||
response = FileDeleteResponse(id="file-123", deleted=True)
|
response = FileDeleteResponse(id="file-123", deleted=True)
|
||||||
assert response.id == "file-123"
|
self.assertEqual(response.id, "file-123")
|
||||||
assert response.object == "file"
|
self.assertEqual(response.object, "file")
|
||||||
assert response.deleted
|
self.assertTrue(response.deleted)
|
||||||
|
|
||||||
|
|
||||||
class TestBatchOperations:
|
class TestBatchOperations(unittest.TestCase):
|
||||||
"""Test batch operation protocol models"""
|
"""Test batch operation protocol models"""
|
||||||
|
|
||||||
def test_batch_request(self):
|
def test_batch_request(self):
|
||||||
@@ -524,10 +526,10 @@ class TestBatchOperations:
|
|||||||
completion_window="24h",
|
completion_window="24h",
|
||||||
metadata={"custom": "value"},
|
metadata={"custom": "value"},
|
||||||
)
|
)
|
||||||
assert request.input_file_id == "file-123"
|
self.assertEqual(request.input_file_id, "file-123")
|
||||||
assert request.endpoint == "/v1/chat/completions"
|
self.assertEqual(request.endpoint, "/v1/chat/completions")
|
||||||
assert request.completion_window == "24h"
|
self.assertEqual(request.completion_window, "24h")
|
||||||
assert request.metadata == {"custom": "value"}
|
self.assertEqual(request.metadata, {"custom": "value"})
|
||||||
|
|
||||||
def test_batch_response(self):
|
def test_batch_response(self):
|
||||||
"""Test batch response model"""
|
"""Test batch response model"""
|
||||||
@@ -538,20 +540,20 @@ class TestBatchOperations:
|
|||||||
completion_window="24h",
|
completion_window="24h",
|
||||||
created_at=1234567890,
|
created_at=1234567890,
|
||||||
)
|
)
|
||||||
assert response.id == "batch-123"
|
self.assertEqual(response.id, "batch-123")
|
||||||
assert response.object == "batch"
|
self.assertEqual(response.object, "batch")
|
||||||
assert response.status == "validating" # default
|
self.assertEqual(response.status, "validating") # default
|
||||||
assert response.endpoint == "/v1/chat/completions"
|
self.assertEqual(response.endpoint, "/v1/chat/completions")
|
||||||
|
|
||||||
|
|
||||||
class TestResponseFormats:
|
class TestResponseFormats(unittest.TestCase):
|
||||||
"""Test response format protocol models"""
|
"""Test response format protocol models"""
|
||||||
|
|
||||||
def test_basic_response_format(self):
|
def test_basic_response_format(self):
|
||||||
"""Test basic response format"""
|
"""Test basic response format"""
|
||||||
format_obj = ResponseFormat(type="json_object")
|
format_obj = ResponseFormat(type="json_object")
|
||||||
assert format_obj.type == "json_object"
|
self.assertEqual(format_obj.type, "json_object")
|
||||||
assert format_obj.json_schema is None
|
self.assertIsNone(format_obj.json_schema)
|
||||||
|
|
||||||
def test_json_schema_response_format(self):
|
def test_json_schema_response_format(self):
|
||||||
"""Test JSON schema response format"""
|
"""Test JSON schema response format"""
|
||||||
@@ -560,9 +562,9 @@ class TestResponseFormats:
|
|||||||
name="person_schema", description="Person schema", schema=schema
|
name="person_schema", description="Person schema", schema=schema
|
||||||
)
|
)
|
||||||
format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
|
format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
|
||||||
assert format_obj.type == "json_schema"
|
self.assertEqual(format_obj.type, "json_schema")
|
||||||
assert format_obj.json_schema.name == "person_schema"
|
self.assertEqual(format_obj.json_schema.name, "person_schema")
|
||||||
assert format_obj.json_schema.schema_ == schema
|
self.assertEqual(format_obj.json_schema.schema_, schema)
|
||||||
|
|
||||||
def test_structural_tag_response_format(self):
|
def test_structural_tag_response_format(self):
|
||||||
"""Test structural tag response format"""
|
"""Test structural tag response format"""
|
||||||
@@ -576,12 +578,12 @@ class TestResponseFormats:
|
|||||||
format_obj = StructuralTagResponseFormat(
|
format_obj = StructuralTagResponseFormat(
|
||||||
type="structural_tag", structures=structures, triggers=["think"]
|
type="structural_tag", structures=structures, triggers=["think"]
|
||||||
)
|
)
|
||||||
assert format_obj.type == "structural_tag"
|
self.assertEqual(format_obj.type, "structural_tag")
|
||||||
assert len(format_obj.structures) == 1
|
self.assertEqual(len(format_obj.structures), 1)
|
||||||
assert format_obj.triggers == ["think"]
|
self.assertEqual(format_obj.triggers, ["think"])
|
||||||
|
|
||||||
|
|
||||||
class TestLogProbs:
|
class TestLogProbs(unittest.TestCase):
|
||||||
"""Test LogProbs protocol models"""
|
"""Test LogProbs protocol models"""
|
||||||
|
|
||||||
def test_basic_logprobs(self):
|
def test_basic_logprobs(self):
|
||||||
@@ -592,9 +594,9 @@ class TestLogProbs:
|
|||||||
tokens=["Hello", " ", "world"],
|
tokens=["Hello", " ", "world"],
|
||||||
top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
|
top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
|
||||||
)
|
)
|
||||||
assert len(logprobs.tokens) == 3
|
self.assertEqual(len(logprobs.tokens), 3)
|
||||||
assert logprobs.tokens == ["Hello", " ", "world"]
|
self.assertEqual(logprobs.tokens, ["Hello", " ", "world"])
|
||||||
assert logprobs.token_logprobs == [-0.1, -0.2, -0.3]
|
self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3])
|
||||||
|
|
||||||
def test_choice_logprobs(self):
|
def test_choice_logprobs(self):
|
||||||
"""Test ChoiceLogprobs model"""
|
"""Test ChoiceLogprobs model"""
|
||||||
@@ -607,17 +609,17 @@ class TestLogProbs:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
choice_logprobs = ChoiceLogprobs(content=[token_logprob])
|
choice_logprobs = ChoiceLogprobs(content=[token_logprob])
|
||||||
assert len(choice_logprobs.content) == 1
|
self.assertEqual(len(choice_logprobs.content), 1)
|
||||||
assert choice_logprobs.content[0].token == "Hello"
|
self.assertEqual(choice_logprobs.content[0].token, "Hello")
|
||||||
|
|
||||||
|
|
||||||
class TestStreamingModels:
|
class TestStreamingModels(unittest.TestCase):
|
||||||
"""Test streaming response models"""
|
"""Test streaming response models"""
|
||||||
|
|
||||||
def test_stream_options(self):
|
def test_stream_options(self):
|
||||||
"""Test StreamOptions model"""
|
"""Test StreamOptions model"""
|
||||||
options = StreamOptions(include_usage=True)
|
options = StreamOptions(include_usage=True)
|
||||||
assert options.include_usage
|
self.assertTrue(options.include_usage)
|
||||||
|
|
||||||
def test_chat_completion_stream_response(self):
|
def test_chat_completion_stream_response(self):
|
||||||
"""Test ChatCompletionStreamResponse model"""
|
"""Test ChatCompletionStreamResponse model"""
|
||||||
@@ -626,29 +628,29 @@ class TestStreamingModels:
|
|||||||
response = ChatCompletionStreamResponse(
|
response = ChatCompletionStreamResponse(
|
||||||
id="test-id", model="test-model", choices=[choice]
|
id="test-id", model="test-model", choices=[choice]
|
||||||
)
|
)
|
||||||
assert response.object == "chat.completion.chunk"
|
self.assertEqual(response.object, "chat.completion.chunk")
|
||||||
assert response.choices[0].delta.content == "Hello"
|
self.assertEqual(response.choices[0].delta.content, "Hello")
|
||||||
|
|
||||||
|
|
||||||
class TestValidationEdgeCases:
|
class TestValidationEdgeCases(unittest.TestCase):
|
||||||
"""Test edge cases and validation scenarios"""
|
"""Test edge cases and validation scenarios"""
|
||||||
|
|
||||||
def test_empty_messages_validation(self):
|
def test_empty_messages_validation(self):
|
||||||
"""Test validation with empty messages"""
|
"""Test validation with empty messages"""
|
||||||
with pytest.raises(ValidationError):
|
with self.assertRaises(ValidationError):
|
||||||
ChatCompletionRequest(model="test-model", messages=[])
|
ChatCompletionRequest(model="test-model", messages=[])
|
||||||
|
|
||||||
def test_invalid_tool_choice_type(self):
|
def test_invalid_tool_choice_type(self):
|
||||||
"""Test invalid tool choice type"""
|
"""Test invalid tool choice type"""
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
with pytest.raises(ValidationError):
|
with self.assertRaises(ValidationError):
|
||||||
ChatCompletionRequest(
|
ChatCompletionRequest(
|
||||||
model="test-model", messages=messages, tool_choice=123
|
model="test-model", messages=messages, tool_choice=123
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_negative_token_limits(self):
|
def test_negative_token_limits(self):
|
||||||
"""Test negative token limits"""
|
"""Test negative token limits"""
|
||||||
with pytest.raises(ValidationError):
|
with self.assertRaises(ValidationError):
|
||||||
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
|
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
|
||||||
|
|
||||||
def test_invalid_temperature_range(self):
|
def test_invalid_temperature_range(self):
|
||||||
@@ -656,7 +658,7 @@ class TestValidationEdgeCases:
|
|||||||
# Note: The current protocol doesn't enforce temperature range,
|
# Note: The current protocol doesn't enforce temperature range,
|
||||||
# but this test documents expected behavior
|
# but this test documents expected behavior
|
||||||
request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
|
request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
|
||||||
assert request.temperature == 5.0 # Currently allowed
|
self.assertEqual(request.temperature, 5.0) # Currently allowed
|
||||||
|
|
||||||
def test_model_serialization_roundtrip(self):
|
def test_model_serialization_roundtrip(self):
|
||||||
"""Test that models can be serialized and deserialized"""
|
"""Test that models can be serialized and deserialized"""
|
||||||
@@ -673,11 +675,11 @@ class TestValidationEdgeCases:
|
|||||||
# Deserialize back
|
# Deserialize back
|
||||||
restored_request = ChatCompletionRequest(**data)
|
restored_request = ChatCompletionRequest(**data)
|
||||||
|
|
||||||
assert restored_request.model == original_request.model
|
self.assertEqual(restored_request.model, original_request.model)
|
||||||
assert restored_request.temperature == original_request.temperature
|
self.assertEqual(restored_request.temperature, original_request.temperature)
|
||||||
assert restored_request.max_tokens == original_request.max_tokens
|
self.assertEqual(restored_request.max_tokens, original_request.max_tokens)
|
||||||
assert len(restored_request.messages) == len(original_request.messages)
|
self.assertEqual(len(restored_request.messages), len(original_request.messages))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
unittest.main(verbosity=2)
|
||||||
|
|||||||
@@ -1,16 +1,52 @@
|
|||||||
# sglang/test/srt/openai/test_server.py
|
# sglang/test/srt/openai/test_server.py
|
||||||
import pytest
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID
|
||||||
|
|
||||||
|
|
||||||
def test_health(openai_server: str):
|
def test_health(openai_server: str):
|
||||||
r = requests.get(f"{openai_server}/health")
|
r = requests.get(f"{openai_server}/health")
|
||||||
assert r.status_code == 200, r.text
|
assert r.status_code == 200
|
||||||
|
# FastAPI returns an empty body → r.text == ""
|
||||||
assert r.text == ""
|
assert r.text == ""
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail(reason="Endpoint skeleton not implemented yet")
|
|
||||||
def test_models_endpoint(openai_server: str):
|
def test_models_endpoint(openai_server: str):
|
||||||
r = requests.get(f"{openai_server}/v1/models")
|
r = requests.get(f"{openai_server}/v1/models")
|
||||||
# once implemented this should be 200
|
assert r.status_code == 200, r.text
|
||||||
assert r.status_code == 200
|
payload = r.json()
|
||||||
|
|
||||||
|
# Basic contract
|
||||||
|
assert "data" in payload and isinstance(payload["data"], list) and payload["data"]
|
||||||
|
|
||||||
|
# Validate fields of the first model card
|
||||||
|
first = payload["data"][0]
|
||||||
|
for key in ("id", "root", "max_model_len"):
|
||||||
|
assert key in first, f"missing {key} in {first}"
|
||||||
|
|
||||||
|
# max_model_len must be positive
|
||||||
|
assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0
|
||||||
|
|
||||||
|
# The server should report the same model id we launched it with
|
||||||
|
ids = {m["id"] for m in payload["data"]}
|
||||||
|
assert MODEL_ID in ids
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_model_info(openai_server: str):
|
||||||
|
r = requests.get(f"{openai_server}/get_model_info")
|
||||||
|
assert r.status_code == 200, r.text
|
||||||
|
info = r.json()
|
||||||
|
|
||||||
|
expected_keys = {"model_path", "tokenizer_path", "is_generation"}
|
||||||
|
assert expected_keys.issubset(info.keys())
|
||||||
|
|
||||||
|
# model_path must end with the one we passed on the CLI
|
||||||
|
assert info["model_path"].endswith(MODEL_ID)
|
||||||
|
|
||||||
|
# is_generation is documented as a boolean
|
||||||
|
assert isinstance(info["is_generation"], bool)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_route_returns_404(openai_server: str):
|
||||||
|
r = requests.get(f"{openai_server}/definitely-not-a-real-route")
|
||||||
|
assert r.status_code == 404
|
||||||
|
|||||||
@@ -1,41 +1,44 @@
|
|||||||
"""
|
"""
|
||||||
Unit tests for the OpenAIServingChat class from serving_chat.py.
|
Unit-tests for OpenAIServingChat — rewritten to use only the std-lib 'unittest'.
|
||||||
|
Run with either:
|
||||||
These tests ensure that the refactored implementation maintains compatibility
|
python tests/test_serving_chat_unit.py -v
|
||||||
with the original adapter.py functionality.
|
or
|
||||||
|
python -m unittest discover -s tests -p "test_*unit.py" -v
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
|
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
|
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
|
|
||||||
|
|
||||||
# Mock TokenizerManager since it may not be directly importable in tests
|
class _MockTokenizerManager:
|
||||||
class MockTokenizerManager:
|
"""Minimal mock that satisfies OpenAIServingChat."""
|
||||||
def __init__(self):
|
|
||||||
self.model_config = Mock()
|
|
||||||
self.model_config.is_multimodal = False
|
|
||||||
self.server_args = Mock()
|
|
||||||
self.server_args.enable_cache_report = False
|
|
||||||
self.server_args.tool_call_parser = "hermes"
|
|
||||||
self.server_args.reasoning_parser = None
|
|
||||||
self.chat_template_name = "llama-3"
|
|
||||||
|
|
||||||
# Mock tokenizer
|
def __init__(self):
|
||||||
|
self.model_config = Mock(is_multimodal=False)
|
||||||
|
self.server_args = Mock(
|
||||||
|
enable_cache_report=False,
|
||||||
|
tool_call_parser="hermes",
|
||||||
|
reasoning_parser=None,
|
||||||
|
)
|
||||||
|
self.chat_template_name: Optional[str] = "llama-3"
|
||||||
|
|
||||||
|
# tokenizer stub
|
||||||
self.tokenizer = Mock()
|
self.tokenizer = Mock()
|
||||||
self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
|
self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
|
||||||
self.tokenizer.decode = Mock(return_value="Test response")
|
self.tokenizer.decode.return_value = "Test response"
|
||||||
self.tokenizer.chat_template = None
|
self.tokenizer.chat_template = None
|
||||||
self.tokenizer.bos_token_id = 1
|
self.tokenizer.bos_token_id = 1
|
||||||
|
|
||||||
# Mock generate_request method
|
# async generator stub for generate_request
|
||||||
async def mock_generate():
|
async def _mock_generate():
|
||||||
yield {
|
yield {
|
||||||
"text": "Test response",
|
"text": "Test response",
|
||||||
"meta_info": {
|
"meta_info": {
|
||||||
@@ -50,585 +53,176 @@ class MockTokenizerManager:
|
|||||||
"index": 0,
|
"index": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.generate_request = Mock(return_value=mock_generate())
|
self.generate_request = Mock(return_value=_mock_generate())
|
||||||
self.create_abort_task = Mock(return_value=None)
|
self.create_abort_task = Mock()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
class ServingChatTestCase(unittest.TestCase):
|
||||||
def mock_tokenizer_manager():
|
# ------------- common fixtures -------------
|
||||||
"""Create a mock tokenizer manager for testing."""
|
def setUp(self):
|
||||||
return MockTokenizerManager()
|
self.tm = _MockTokenizerManager()
|
||||||
|
self.chat = OpenAIServingChat(self.tm)
|
||||||
|
|
||||||
|
# frequently reused requests
|
||||||
|
self.basic_req = ChatCompletionRequest(
|
||||||
|
model="x",
|
||||||
|
messages=[{"role": "user", "content": "Hi?"}],
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=100,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
self.stream_req = ChatCompletionRequest(
|
||||||
|
model="x",
|
||||||
|
messages=[{"role": "user", "content": "Hi?"}],
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=100,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.fixture
|
self.fastapi_request = Mock(spec=Request)
|
||||||
def serving_chat(mock_tokenizer_manager):
|
self.fastapi_request.headers = {}
|
||||||
"""Create a OpenAIServingChat instance for testing."""
|
|
||||||
return OpenAIServingChat(mock_tokenizer_manager)
|
|
||||||
|
|
||||||
|
# ------------- conversion tests -------------
|
||||||
@pytest.fixture
|
def test_convert_to_internal_request_single(self):
|
||||||
def mock_request():
|
|
||||||
"""Create a mock FastAPI request."""
|
|
||||||
request = Mock(spec=Request)
|
|
||||||
request.headers = {}
|
|
||||||
return request
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def basic_chat_request():
|
|
||||||
"""Create a basic chat completion request."""
|
|
||||||
return ChatCompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=100,
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def streaming_chat_request():
|
|
||||||
"""Create a streaming chat completion request."""
|
|
||||||
return ChatCompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=100,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIServingChatConversion:
|
|
||||||
"""Test request conversion methods."""
|
|
||||||
|
|
||||||
def test_convert_to_internal_request_single(
|
|
||||||
self, serving_chat, basic_chat_request, mock_tokenizer_manager
|
|
||||||
):
|
|
||||||
"""Test converting single request to internal format."""
|
|
||||||
with patch(
|
with patch(
|
||||||
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
|
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
|
||||||
) as mock_conv:
|
) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
|
||||||
mock_conv_instance = Mock()
|
conv_ins = Mock()
|
||||||
mock_conv_instance.get_prompt.return_value = "Test prompt"
|
conv_ins.get_prompt.return_value = "Test prompt"
|
||||||
mock_conv_instance.image_data = None
|
conv_ins.image_data = conv_ins.audio_data = None
|
||||||
mock_conv_instance.audio_data = None
|
conv_ins.modalities = []
|
||||||
mock_conv_instance.modalities = []
|
conv_ins.stop_str = ["</s>"]
|
||||||
mock_conv_instance.stop_str = ["</s>"]
|
conv_mock.return_value = conv_ins
|
||||||
mock_conv.return_value = mock_conv_instance
|
|
||||||
|
|
||||||
# Mock the _process_messages method to return expected values
|
proc_mock.return_value = (
|
||||||
with patch.object(serving_chat, "_process_messages") as mock_process:
|
"Test prompt",
|
||||||
mock_process.return_value = (
|
[1, 2, 3],
|
||||||
"Test prompt",
|
None,
|
||||||
[1, 2, 3],
|
None,
|
||||||
None,
|
[],
|
||||||
None,
|
["</s>"],
|
||||||
[],
|
None,
|
||||||
["</s>"],
|
)
|
||||||
None, # tool_call_constraint
|
|
||||||
)
|
|
||||||
|
|
||||||
adapted_request, processed_request = (
|
adapted, processed = self.chat._convert_to_internal_request(
|
||||||
serving_chat._convert_to_internal_request(
|
[self.basic_req], ["rid"]
|
||||||
[basic_chat_request], ["test-id"]
|
)
|
||||||
)
|
self.assertIsInstance(adapted, GenerateReqInput)
|
||||||
)
|
self.assertFalse(adapted.stream)
|
||||||
|
self.assertEqual(processed, self.basic_req)
|
||||||
|
|
||||||
assert isinstance(adapted_request, GenerateReqInput)
|
# ------------- tool-call branch -------------
|
||||||
assert adapted_request.stream == basic_chat_request.stream
|
def test_tool_call_request_conversion(self):
|
||||||
assert processed_request == basic_chat_request
|
req = ChatCompletionRequest(
|
||||||
|
model="x",
|
||||||
|
messages=[{"role": "user", "content": "Weather?"}],
|
||||||
class TestToolCalls:
|
|
||||||
"""Test tool call functionality from adapter.py"""
|
|
||||||
|
|
||||||
def test_tool_call_request_conversion(self, serving_chat):
|
|
||||||
"""Test request with tool calls"""
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
messages=[{"role": "user", "content": "What's the weather?"}],
|
|
||||||
tools=[
|
tools=[
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "get_weather",
|
"name": "get_weather",
|
||||||
"description": "Get weather information",
|
"parameters": {"type": "object", "properties": {}},
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"location": {"type": "string"}},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(serving_chat, "_process_messages") as mock_process:
|
with patch.object(
|
||||||
mock_process.return_value = (
|
self.chat,
|
||||||
"Test prompt",
|
"_process_messages",
|
||||||
[1, 2, 3],
|
return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||||
None,
|
):
|
||||||
None,
|
adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
|
||||||
[],
|
self.assertEqual(adapted.rid, "rid")
|
||||||
["</s>"],
|
|
||||||
None, # tool_call_constraint
|
|
||||||
)
|
|
||||||
|
|
||||||
adapted_request, _ = serving_chat._convert_to_internal_request(
|
def test_tool_choice_none(self):
|
||||||
[request], ["test-id"]
|
req = ChatCompletionRequest(
|
||||||
)
|
model="x",
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
assert adapted_request.rid == "test-id"
|
tools=[{"type": "function", "function": {"name": "noop"}}],
|
||||||
# Tool call constraint should be processed
|
|
||||||
assert request.tools is not None
|
|
||||||
|
|
||||||
def test_tool_choice_none(self, serving_chat):
|
|
||||||
"""Test tool_choice=none disables tool calls"""
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
messages=[{"role": "user", "content": "Hello"}],
|
|
||||||
tools=[{"type": "function", "function": {"name": "test_func"}}],
|
|
||||||
tool_choice="none",
|
tool_choice="none",
|
||||||
)
|
)
|
||||||
|
with patch.object(
|
||||||
|
self.chat,
|
||||||
|
"_process_messages",
|
||||||
|
return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||||
|
):
|
||||||
|
adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
|
||||||
|
self.assertEqual(adapted.rid, "rid")
|
||||||
|
|
||||||
with patch.object(serving_chat, "_process_messages") as mock_process:
|
# ------------- multimodal branch -------------
|
||||||
mock_process.return_value = (
|
def test_multimodal_request_with_images(self):
|
||||||
"Test prompt",
|
self.tm.model_config.is_multimodal = True
|
||||||
[1, 2, 3],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
[],
|
|
||||||
["</s>"],
|
|
||||||
None, # tool_call_constraint
|
|
||||||
)
|
|
||||||
|
|
||||||
adapted_request, _ = serving_chat._convert_to_internal_request(
|
req = ChatCompletionRequest(
|
||||||
[request], ["test-id"]
|
model="x",
|
||||||
)
|
|
||||||
|
|
||||||
# Tools should not be processed when tool_choice is "none"
|
|
||||||
assert adapted_request.rid == "test-id"
|
|
||||||
|
|
||||||
def test_tool_call_response_processing(self, serving_chat):
|
|
||||||
"""Test processing tool calls in response"""
|
|
||||||
mock_ret_item = {
|
|
||||||
"text": '{"name": "get_weather", "parameters": {"location": "Paris"}}',
|
|
||||||
"meta_info": {
|
|
||||||
"output_token_logprobs": [],
|
|
||||||
"output_top_logprobs": None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
tools = [
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "get_weather",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"location": {"type": "string"}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
finish_reason = {"type": "stop", "matched": None}
|
|
||||||
|
|
||||||
# Mock FunctionCallParser
|
|
||||||
with patch(
|
|
||||||
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
|
||||||
) as mock_parser_class:
|
|
||||||
mock_parser = Mock()
|
|
||||||
mock_parser.has_tool_call.return_value = True
|
|
||||||
|
|
||||||
# Create proper mock tool call object
|
|
||||||
mock_tool_call = Mock()
|
|
||||||
mock_tool_call.name = "get_weather"
|
|
||||||
mock_tool_call.parameters = '{"location": "Paris"}'
|
|
||||||
|
|
||||||
mock_parser.parse_non_stream.return_value = ("", [mock_tool_call])
|
|
||||||
mock_parser_class.return_value = mock_parser
|
|
||||||
|
|
||||||
tool_calls, text, updated_finish_reason = serving_chat._process_tool_calls(
|
|
||||||
mock_ret_item["text"], tools, "hermes", finish_reason
|
|
||||||
)
|
|
||||||
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 1
|
|
||||||
assert updated_finish_reason["type"] == "tool_calls"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMultimodalContent:
|
|
||||||
"""Test multimodal content handling from adapter.py"""
|
|
||||||
|
|
||||||
def test_multimodal_request_with_images(self, serving_chat):
|
|
||||||
"""Test request with image content"""
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "text", "text": "What's in this image?"},
|
{"type": "text", "text": "What's in the image?"},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {"url": "data:image/jpeg;base64,..."},
|
"image_url": {"url": "data:image/jpeg;base64,"},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set multimodal mode
|
with patch.object(
|
||||||
serving_chat.tokenizer_manager.model_config.is_multimodal = True
|
self.chat,
|
||||||
|
"_apply_jinja_template",
|
||||||
|
return_value=("prompt", [1, 2], ["img"], None, [], []),
|
||||||
|
), patch.object(
|
||||||
|
self.chat,
|
||||||
|
"_apply_conversation_template",
|
||||||
|
return_value=("prompt", ["img"], None, [], []),
|
||||||
|
):
|
||||||
|
out = self.chat._process_messages(req, True)
|
||||||
|
_, _, image_data, *_ = out
|
||||||
|
self.assertEqual(image_data, ["img"])
|
||||||
|
|
||||||
with patch.object(serving_chat, "_apply_jinja_template") as mock_apply:
|
# ------------- template handling -------------
|
||||||
mock_apply.return_value = (
|
def test_jinja_template_processing(self):
|
||||||
"prompt",
|
req = ChatCompletionRequest(
|
||||||
[1, 2, 3],
|
model="x", messages=[{"role": "user", "content": "Hello"}]
|
||||||
["image_data"],
|
|
||||||
None,
|
|
||||||
[],
|
|
||||||
[],
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(
|
|
||||||
serving_chat, "_apply_conversation_template"
|
|
||||||
) as mock_conv:
|
|
||||||
mock_conv.return_value = ("prompt", ["image_data"], None, [], [])
|
|
||||||
|
|
||||||
(
|
|
||||||
prompt,
|
|
||||||
prompt_ids,
|
|
||||||
image_data,
|
|
||||||
audio_data,
|
|
||||||
modalities,
|
|
||||||
stop,
|
|
||||||
tool_call_constraint,
|
|
||||||
) = serving_chat._process_messages(request, True)
|
|
||||||
|
|
||||||
assert image_data == ["image_data"]
|
|
||||||
assert prompt == "prompt"
|
|
||||||
|
|
||||||
def test_multimodal_request_with_audio(self, serving_chat):
|
|
||||||
"""Test request with audio content"""
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "Transcribe this audio"},
|
|
||||||
{
|
|
||||||
"type": "audio_url",
|
|
||||||
"audio_url": {"url": "data:audio/wav;base64,UklGR..."},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
self.tm.chat_template_name = None
|
||||||
|
self.tm.tokenizer.chat_template = "<jinja>"
|
||||||
|
|
||||||
serving_chat.tokenizer_manager.model_config.is_multimodal = True
|
with patch.object(
|
||||||
|
self.chat,
|
||||||
|
"_apply_jinja_template",
|
||||||
|
return_value=("processed", [1], None, None, [], ["</s>"]),
|
||||||
|
), patch("builtins.hasattr", return_value=True):
|
||||||
|
prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
|
||||||
|
self.assertEqual(prompt, "processed")
|
||||||
|
self.assertEqual(prompt_ids, [1])
|
||||||
|
|
||||||
with patch.object(serving_chat, "_apply_jinja_template") as mock_apply:
|
# ------------- sampling-params -------------
|
||||||
mock_apply.return_value = (
|
def test_sampling_param_build(self):
|
||||||
"prompt",
|
req = ChatCompletionRequest(
|
||||||
[1, 2, 3],
|
model="x",
|
||||||
None,
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
["audio_data"],
|
|
||||||
["audio"],
|
|
||||||
[],
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(
|
|
||||||
serving_chat, "_apply_conversation_template"
|
|
||||||
) as mock_conv:
|
|
||||||
mock_conv.return_value = ("prompt", None, ["audio_data"], ["audio"], [])
|
|
||||||
|
|
||||||
(
|
|
||||||
prompt,
|
|
||||||
prompt_ids,
|
|
||||||
image_data,
|
|
||||||
audio_data,
|
|
||||||
modalities,
|
|
||||||
stop,
|
|
||||||
tool_call_constraint,
|
|
||||||
) = serving_chat._process_messages(request, True)
|
|
||||||
|
|
||||||
assert audio_data == ["audio_data"]
|
|
||||||
assert modalities == ["audio"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestTemplateHandling:
|
|
||||||
"""Test chat template handling from adapter.py"""
|
|
||||||
|
|
||||||
def test_jinja_template_processing(self, serving_chat):
|
|
||||||
"""Test Jinja template processing"""
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model", messages=[{"role": "user", "content": "Hello"}]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the template attribute directly
|
|
||||||
serving_chat.tokenizer_manager.chat_template_name = None
|
|
||||||
serving_chat.tokenizer_manager.tokenizer.chat_template = "<jinja_template>"
|
|
||||||
|
|
||||||
with patch.object(serving_chat, "_apply_jinja_template") as mock_apply:
|
|
||||||
mock_apply.return_value = (
|
|
||||||
"processed_prompt",
|
|
||||||
[1, 2, 3],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
[],
|
|
||||||
["</s>"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock hasattr to simulate the None check
|
|
||||||
with patch("builtins.hasattr") as mock_hasattr:
|
|
||||||
mock_hasattr.return_value = True
|
|
||||||
|
|
||||||
(
|
|
||||||
prompt,
|
|
||||||
prompt_ids,
|
|
||||||
image_data,
|
|
||||||
audio_data,
|
|
||||||
modalities,
|
|
||||||
stop,
|
|
||||||
tool_call_constraint,
|
|
||||||
) = serving_chat._process_messages(request, False)
|
|
||||||
|
|
||||||
assert prompt == "processed_prompt"
|
|
||||||
assert prompt_ids == [1, 2, 3]
|
|
||||||
|
|
||||||
def test_conversation_template_processing(self, serving_chat):
|
|
||||||
"""Test conversation template processing"""
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model", messages=[{"role": "user", "content": "Hello"}]
|
|
||||||
)
|
|
||||||
|
|
||||||
serving_chat.tokenizer_manager.chat_template_name = "llama-3"
|
|
||||||
|
|
||||||
with patch.object(serving_chat, "_apply_conversation_template") as mock_apply:
|
|
||||||
mock_apply.return_value = ("conv_prompt", None, None, [], ["</s>"])
|
|
||||||
|
|
||||||
(
|
|
||||||
prompt,
|
|
||||||
prompt_ids,
|
|
||||||
image_data,
|
|
||||||
audio_data,
|
|
||||||
modalities,
|
|
||||||
stop,
|
|
||||||
tool_call_constraint,
|
|
||||||
) = serving_chat._process_messages(request, False)
|
|
||||||
|
|
||||||
assert prompt == "conv_prompt"
|
|
||||||
assert stop == ["</s>"]
|
|
||||||
|
|
||||||
def test_continue_final_message(self, serving_chat):
|
|
||||||
"""Test continue_final_message functionality"""
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
messages=[
|
|
||||||
{"role": "user", "content": "Hello"},
|
|
||||||
{"role": "assistant", "content": "Hi there"},
|
|
||||||
],
|
|
||||||
continue_final_message=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(serving_chat, "_apply_conversation_template") as mock_apply:
|
|
||||||
mock_apply.return_value = ("Hi there", None, None, [], ["</s>"])
|
|
||||||
|
|
||||||
(
|
|
||||||
prompt,
|
|
||||||
prompt_ids,
|
|
||||||
image_data,
|
|
||||||
audio_data,
|
|
||||||
modalities,
|
|
||||||
stop,
|
|
||||||
tool_call_constraint,
|
|
||||||
) = serving_chat._process_messages(request, False)
|
|
||||||
|
|
||||||
# Should handle continue_final_message properly
|
|
||||||
assert prompt == "Hi there"
|
|
||||||
|
|
||||||
|
|
||||||
class TestReasoningContent:
|
|
||||||
"""Test reasoning content separation from adapter.py"""
|
|
||||||
|
|
||||||
def test_reasoning_content_request(self, serving_chat):
|
|
||||||
"""Test request with reasoning content separation"""
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
messages=[{"role": "user", "content": "Solve this math problem"}],
|
|
||||||
separate_reasoning=True,
|
|
||||||
stream_reasoning=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(serving_chat, "_process_messages") as mock_process:
|
|
||||||
mock_process.return_value = (
|
|
||||||
"Test prompt",
|
|
||||||
[1, 2, 3],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
[],
|
|
||||||
["</s>"],
|
|
||||||
None, # tool_call_constraint
|
|
||||||
)
|
|
||||||
|
|
||||||
adapted_request, _ = serving_chat._convert_to_internal_request(
|
|
||||||
[request], ["test-id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert adapted_request.rid == "test-id"
|
|
||||||
assert request.separate_reasoning == True
|
|
||||||
|
|
||||||
def test_reasoning_content_response(self, serving_chat):
|
|
||||||
"""Test reasoning content in response"""
|
|
||||||
mock_ret_item = {
|
|
||||||
"text": "<thinking>This is reasoning</thinking>Answer: 42",
|
|
||||||
"meta_info": {
|
|
||||||
"output_token_logprobs": [],
|
|
||||||
"output_top_logprobs": None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Mock ReasoningParser
|
|
||||||
with patch(
|
|
||||||
"sglang.srt.entrypoints.openai.serving_chat.ReasoningParser"
|
|
||||||
) as mock_parser_class:
|
|
||||||
mock_parser = Mock()
|
|
||||||
mock_parser.parse_non_stream.return_value = (
|
|
||||||
"This is reasoning",
|
|
||||||
"Answer: 42",
|
|
||||||
)
|
|
||||||
mock_parser_class.return_value = mock_parser
|
|
||||||
|
|
||||||
choice_logprobs = None
|
|
||||||
reasoning_text = None
|
|
||||||
text = mock_ret_item["text"]
|
|
||||||
|
|
||||||
# Simulate reasoning processing
|
|
||||||
enable_thinking = True
|
|
||||||
if enable_thinking:
|
|
||||||
parser = mock_parser_class(model_type="test", stream_reasoning=False)
|
|
||||||
reasoning_text, text = parser.parse_non_stream(text)
|
|
||||||
|
|
||||||
assert reasoning_text == "This is reasoning"
|
|
||||||
assert text == "Answer: 42"
|
|
||||||
|
|
||||||
|
|
||||||
class TestSamplingParams:
|
|
||||||
"""Test sampling parameter handling from adapter.py"""
|
|
||||||
|
|
||||||
def test_all_sampling_parameters(self, serving_chat):
|
|
||||||
"""Test all sampling parameters are properly handled"""
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
messages=[{"role": "user", "content": "Hello"}],
|
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
max_tokens=150,
|
max_tokens=150,
|
||||||
max_completion_tokens=200,
|
|
||||||
min_tokens=5,
|
min_tokens=5,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
top_k=50,
|
stop=["</s>"],
|
||||||
min_p=0.1,
|
|
||||||
presence_penalty=0.1,
|
|
||||||
frequency_penalty=0.2,
|
|
||||||
repetition_penalty=1.1,
|
|
||||||
stop=["<|endoftext|>"],
|
|
||||||
stop_token_ids=[13, 14],
|
|
||||||
regex=r"\d+",
|
|
||||||
ebnf="<expr> ::= <number>",
|
|
||||||
n=2,
|
|
||||||
no_stop_trim=True,
|
|
||||||
ignore_eos=True,
|
|
||||||
skip_special_tokens=False,
|
|
||||||
logit_bias={"1": 0.5, "2": -0.3},
|
|
||||||
)
|
)
|
||||||
|
with patch.object(
|
||||||
|
self.chat,
|
||||||
|
"_process_messages",
|
||||||
|
return_value=("Prompt", [1], None, None, [], ["</s>"], None),
|
||||||
|
):
|
||||||
|
params = self.chat._build_sampling_params(req, ["</s>"], None)
|
||||||
|
self.assertEqual(params["temperature"], 0.8)
|
||||||
|
self.assertEqual(params["max_new_tokens"], 150)
|
||||||
|
self.assertEqual(params["min_new_tokens"], 5)
|
||||||
|
self.assertEqual(params["stop"], ["</s>"])
|
||||||
|
|
||||||
with patch.object(serving_chat, "_process_messages") as mock_process:
|
|
||||||
mock_process.return_value = (
|
|
||||||
"Test prompt",
|
|
||||||
[1, 2, 3],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
[],
|
|
||||||
["</s>"],
|
|
||||||
None, # tool_call_constraint
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_params = serving_chat._build_sampling_params(
|
if __name__ == "__main__":
|
||||||
request, ["</s>"], None
|
unittest.main(verbosity=2)
|
||||||
)
|
|
||||||
|
|
||||||
# Verify all parameters
|
|
||||||
assert sampling_params["temperature"] == 0.8
|
|
||||||
assert sampling_params["max_new_tokens"] == 150
|
|
||||||
assert sampling_params["min_new_tokens"] == 5
|
|
||||||
assert sampling_params["top_p"] == 0.9
|
|
||||||
assert sampling_params["top_k"] == 50
|
|
||||||
assert sampling_params["min_p"] == 0.1
|
|
||||||
assert sampling_params["presence_penalty"] == 0.1
|
|
||||||
assert sampling_params["frequency_penalty"] == 0.2
|
|
||||||
assert sampling_params["repetition_penalty"] == 1.1
|
|
||||||
assert sampling_params["stop"] == ["</s>"]
|
|
||||||
assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3}
|
|
||||||
|
|
||||||
def test_response_format_json_schema(self, serving_chat):
|
|
||||||
"""Test response format with JSON schema"""
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
messages=[{"role": "user", "content": "Generate JSON"}],
|
|
||||||
response_format={
|
|
||||||
"type": "json_schema",
|
|
||||||
"json_schema": {
|
|
||||||
"name": "response",
|
|
||||||
"schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"answer": {"type": "string"}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(serving_chat, "_process_messages") as mock_process:
|
|
||||||
mock_process.return_value = (
|
|
||||||
"Test prompt",
|
|
||||||
[1, 2, 3],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
[],
|
|
||||||
["</s>"],
|
|
||||||
None, # tool_call_constraint
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_params = serving_chat._build_sampling_params(
|
|
||||||
request, ["</s>"], None
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "json_schema" in sampling_params
|
|
||||||
assert '"type": "object"' in sampling_params["json_schema"]
|
|
||||||
|
|
||||||
def test_response_format_json_object(self, serving_chat):
|
|
||||||
"""Test response format with JSON object"""
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
messages=[{"role": "user", "content": "Generate JSON"}],
|
|
||||||
response_format={"type": "json_object"},
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(serving_chat, "_process_messages") as mock_process:
|
|
||||||
mock_process.return_value = (
|
|
||||||
"Test prompt",
|
|
||||||
[1, 2, 3],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
[],
|
|
||||||
["</s>"],
|
|
||||||
None, # tool_call_constraint
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_params = serving_chat._build_sampling_params(
|
|
||||||
request, ["</s>"], None
|
|
||||||
)
|
|
||||||
|
|
||||||
assert sampling_params["json_schema"] == '{"type": "object"}'
|
|
||||||
|
|||||||
@@ -1,176 +1,101 @@
|
|||||||
"""
|
"""
|
||||||
Tests for the refactored completions serving handler
|
Unit-tests for the refactored completions-serving handler (no pytest).
|
||||||
|
Run with:
|
||||||
|
python -m unittest tests.test_serving_completions_unit -v
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
|
||||||
CompletionRequest,
|
|
||||||
CompletionResponse,
|
|
||||||
CompletionResponseChoice,
|
|
||||||
CompletionStreamResponse,
|
|
||||||
ErrorResponse,
|
|
||||||
)
|
|
||||||
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
|
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
class ServingCompletionTestCase(unittest.TestCase):
|
||||||
def mock_tokenizer_manager():
|
"""Bundle all prompt/echo tests in one TestCase."""
|
||||||
"""Create a mock tokenizer manager"""
|
|
||||||
manager = Mock(spec=TokenizerManager)
|
|
||||||
|
|
||||||
# Mock tokenizer
|
# ---------- shared test fixtures ----------
|
||||||
manager.tokenizer = Mock()
|
def setUp(self):
|
||||||
manager.tokenizer.encode = Mock(return_value=[1, 2, 3, 4])
|
# build the mock TokenizerManager once for every test
|
||||||
manager.tokenizer.decode = Mock(return_value="decoded text")
|
tm = Mock(spec=TokenizerManager)
|
||||||
manager.tokenizer.bos_token_id = 1
|
|
||||||
|
|
||||||
# Mock model config
|
tm.tokenizer = Mock()
|
||||||
manager.model_config = Mock()
|
tm.tokenizer.encode.return_value = [1, 2, 3, 4]
|
||||||
manager.model_config.is_multimodal = False
|
tm.tokenizer.decode.return_value = "decoded text"
|
||||||
|
tm.tokenizer.bos_token_id = 1
|
||||||
|
|
||||||
# Mock server args
|
tm.model_config = Mock(is_multimodal=False)
|
||||||
manager.server_args = Mock()
|
tm.server_args = Mock(enable_cache_report=False)
|
||||||
manager.server_args.enable_cache_report = False
|
|
||||||
|
|
||||||
# Mock generation
|
tm.generate_request = AsyncMock()
|
||||||
manager.generate_request = AsyncMock()
|
tm.create_abort_task = Mock()
|
||||||
manager.create_abort_task = Mock(return_value=None)
|
|
||||||
|
|
||||||
return manager
|
self.sc = OpenAIServingCompletion(tm)
|
||||||
|
|
||||||
|
# ---------- prompt-handling ----------
|
||||||
|
def test_single_string_prompt(self):
|
||||||
|
req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
|
||||||
|
internal, _ = self.sc._convert_to_internal_request([req], ["id"])
|
||||||
|
self.assertEqual(internal.text, "Hello world")
|
||||||
|
|
||||||
@pytest.fixture
|
def test_single_token_ids_prompt(self):
|
||||||
def serving_completion(mock_tokenizer_manager):
|
req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
|
||||||
"""Create a OpenAIServingCompletion instance"""
|
internal, _ = self.sc._convert_to_internal_request([req], ["id"])
|
||||||
return OpenAIServingCompletion(mock_tokenizer_manager)
|
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
|
||||||
|
|
||||||
|
def test_completion_template_handling(self):
|
||||||
class TestPromptHandling:
|
req = CompletionRequest(
|
||||||
"""Test different prompt types and formats from adapter.py"""
|
model="x", prompt="def f():", suffix="return 1", max_tokens=100
|
||||||
|
|
||||||
def test_single_string_prompt(self, serving_completion):
|
|
||||||
"""Test handling single string prompt"""
|
|
||||||
request = CompletionRequest(
|
|
||||||
model="test-model", prompt="Hello world", max_tokens=100
|
|
||||||
)
|
)
|
||||||
|
|
||||||
adapted_request, _ = serving_completion._convert_to_internal_request(
|
|
||||||
[request], ["test-id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert adapted_request.text == "Hello world"
|
|
||||||
|
|
||||||
def test_single_token_ids_prompt(self, serving_completion):
|
|
||||||
"""Test handling single token IDs prompt"""
|
|
||||||
request = CompletionRequest(
|
|
||||||
model="test-model", prompt=[1, 2, 3, 4], max_tokens=100
|
|
||||||
)
|
|
||||||
|
|
||||||
adapted_request, _ = serving_completion._convert_to_internal_request(
|
|
||||||
[request], ["test-id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert adapted_request.input_ids == [1, 2, 3, 4]
|
|
||||||
|
|
||||||
def test_completion_template_handling(self, serving_completion):
|
|
||||||
"""Test completion template processing"""
|
|
||||||
request = CompletionRequest(
|
|
||||||
model="test-model",
|
|
||||||
prompt="def hello():",
|
|
||||||
suffix="return 'world'",
|
|
||||||
max_tokens=100,
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
|
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
|
||||||
return_value=True,
|
return_value=True,
|
||||||
|
), patch(
|
||||||
|
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
|
||||||
|
return_value="processed_prompt",
|
||||||
):
|
):
|
||||||
with patch(
|
internal, _ = self.sc._convert_to_internal_request([req], ["id"])
|
||||||
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
|
self.assertEqual(internal.text, "processed_prompt")
|
||||||
return_value="processed_prompt",
|
|
||||||
):
|
|
||||||
adapted_request, _ = serving_completion._convert_to_internal_request(
|
|
||||||
[request], ["test-id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert adapted_request.text == "processed_prompt"
|
# ---------- echo-handling ----------
|
||||||
|
def test_echo_with_string_prompt_streaming(self):
|
||||||
|
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
|
||||||
|
self.assertEqual(self.sc._get_echo_text(req, 0), "Hello")
|
||||||
|
|
||||||
|
def test_echo_with_list_of_strings_streaming(self):
|
||||||
class TestEchoHandling:
|
req = CompletionRequest(
|
||||||
"""Test echo functionality from adapter.py"""
|
model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1
|
||||||
|
|
||||||
def test_echo_with_string_prompt_streaming(self, serving_completion):
|
|
||||||
"""Test echo handling with string prompt in streaming"""
|
|
||||||
request = CompletionRequest(
|
|
||||||
model="test-model", prompt="Hello", max_tokens=100, echo=True
|
|
||||||
)
|
)
|
||||||
|
self.assertEqual(self.sc._get_echo_text(req, 0), "A")
|
||||||
|
self.assertEqual(self.sc._get_echo_text(req, 1), "B")
|
||||||
|
|
||||||
# Test _get_echo_text method
|
def test_echo_with_token_ids_streaming(self):
|
||||||
echo_text = serving_completion._get_echo_text(request, 0)
|
req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True)
|
||||||
assert echo_text == "Hello"
|
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt"
|
||||||
|
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt")
|
||||||
|
|
||||||
def test_echo_with_list_of_strings_streaming(self, serving_completion):
|
def test_echo_with_multiple_token_ids_streaming(self):
|
||||||
"""Test echo handling with list of strings in streaming"""
|
req = CompletionRequest(
|
||||||
request = CompletionRequest(
|
model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1
|
||||||
model="test-model",
|
|
||||||
prompt=["Hello", "World"],
|
|
||||||
max_tokens=100,
|
|
||||||
echo=True,
|
|
||||||
n=1,
|
|
||||||
)
|
)
|
||||||
|
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
||||||
|
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded")
|
||||||
|
|
||||||
echo_text = serving_completion._get_echo_text(request, 0)
|
def test_prepare_echo_prompts_non_streaming(self):
|
||||||
assert echo_text == "Hello"
|
# single string
|
||||||
|
req = CompletionRequest(model="x", prompt="Hi", echo=True)
|
||||||
|
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])
|
||||||
|
|
||||||
echo_text = serving_completion._get_echo_text(request, 1)
|
# list of strings
|
||||||
assert echo_text == "World"
|
req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
|
||||||
|
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])
|
||||||
|
|
||||||
def test_echo_with_token_ids_streaming(self, serving_completion):
|
# token IDs
|
||||||
"""Test echo handling with token IDs in streaming"""
|
req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
|
||||||
request = CompletionRequest(
|
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
||||||
model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True
|
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
|
||||||
)
|
|
||||||
|
|
||||||
serving_completion.tokenizer_manager.tokenizer.decode.return_value = (
|
|
||||||
"decoded_prompt"
|
|
||||||
)
|
|
||||||
echo_text = serving_completion._get_echo_text(request, 0)
|
|
||||||
assert echo_text == "decoded_prompt"
|
|
||||||
|
|
||||||
def test_echo_with_multiple_token_ids_streaming(self, serving_completion):
|
if __name__ == "__main__":
|
||||||
"""Test echo handling with multiple token ID prompts in streaming"""
|
unittest.main(verbosity=2)
|
||||||
request = CompletionRequest(
|
|
||||||
model="test-model", prompt=[[1, 2], [3, 4]], max_tokens=100, echo=True, n=1
|
|
||||||
)
|
|
||||||
|
|
||||||
serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
|
||||||
echo_text = serving_completion._get_echo_text(request, 0)
|
|
||||||
assert echo_text == "decoded"
|
|
||||||
|
|
||||||
def test_prepare_echo_prompts_non_streaming(self, serving_completion):
|
|
||||||
"""Test prepare echo prompts for non-streaming response"""
|
|
||||||
# Test with single string
|
|
||||||
request = CompletionRequest(model="test-model", prompt="Hello", echo=True)
|
|
||||||
|
|
||||||
echo_prompts = serving_completion._prepare_echo_prompts(request)
|
|
||||||
assert echo_prompts == ["Hello"]
|
|
||||||
|
|
||||||
# Test with list of strings
|
|
||||||
request = CompletionRequest(
|
|
||||||
model="test-model", prompt=["Hello", "World"], echo=True
|
|
||||||
)
|
|
||||||
|
|
||||||
echo_prompts = serving_completion._prepare_echo_prompts(request)
|
|
||||||
assert echo_prompts == ["Hello", "World"]
|
|
||||||
|
|
||||||
# Test with token IDs
|
|
||||||
request = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True)
|
|
||||||
|
|
||||||
serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
|
||||||
echo_prompts = serving_completion._prepare_echo_prompts(request)
|
|
||||||
assert echo_prompts == ["decoded"]
|
|
||||||
|
|||||||
@@ -8,11 +8,11 @@ with the original adapter.py functionality and follows OpenAI API specifications
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from pydantic_core import ValidationError
|
from pydantic_core import ValidationError
|
||||||
@@ -30,7 +30,7 @@ from sglang.srt.managers.io_struct import EmbeddingReqInput
|
|||||||
|
|
||||||
|
|
||||||
# Mock TokenizerManager for embedding tests
|
# Mock TokenizerManager for embedding tests
|
||||||
class MockTokenizerManager:
|
class _MockTokenizerManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.model_config = Mock()
|
self.model_config = Mock()
|
||||||
self.model_config.is_multimodal = False
|
self.model_config.is_multimodal = False
|
||||||
@@ -58,141 +58,98 @@ class MockTokenizerManager:
|
|||||||
self.generate_request = Mock(return_value=mock_generate_embedding())
|
self.generate_request = Mock(return_value=mock_generate_embedding())
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
class ServingEmbeddingTestCase(unittest.TestCase):
|
||||||
def mock_tokenizer_manager():
|
def setUp(self):
|
||||||
"""Create a mock tokenizer manager for testing."""
|
"""Set up test fixtures."""
|
||||||
return MockTokenizerManager()
|
self.tokenizer_manager = _MockTokenizerManager()
|
||||||
|
self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
|
||||||
|
|
||||||
|
self.request = Mock(spec=Request)
|
||||||
|
self.request.headers = {}
|
||||||
|
|
||||||
@pytest.fixture
|
self.basic_req = EmbeddingRequest(
|
||||||
def serving_embedding(mock_tokenizer_manager):
|
model="test-model",
|
||||||
"""Create an OpenAIServingEmbedding instance for testing."""
|
input="Hello, how are you?",
|
||||||
return OpenAIServingEmbedding(mock_tokenizer_manager)
|
encoding_format="float",
|
||||||
|
)
|
||||||
|
self.list_req = EmbeddingRequest(
|
||||||
|
model="test-model",
|
||||||
|
input=["Hello, how are you?", "I am fine, thank you!"],
|
||||||
|
encoding_format="float",
|
||||||
|
)
|
||||||
|
self.multimodal_req = EmbeddingRequest(
|
||||||
|
model="test-model",
|
||||||
|
input=[
|
||||||
|
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
|
||||||
|
MultimodalEmbeddingInput(text="World", image=None),
|
||||||
|
],
|
||||||
|
encoding_format="float",
|
||||||
|
)
|
||||||
|
self.token_ids_req = EmbeddingRequest(
|
||||||
|
model="test-model",
|
||||||
|
input=[1, 2, 3, 4, 5],
|
||||||
|
encoding_format="float",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_convert_single_string_request(self):
|
||||||
@pytest.fixture
|
|
||||||
def mock_request():
|
|
||||||
"""Create a mock FastAPI request."""
|
|
||||||
request = Mock(spec=Request)
|
|
||||||
request.headers = {}
|
|
||||||
return request
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def basic_embedding_request():
|
|
||||||
"""Create a basic embedding request."""
|
|
||||||
return EmbeddingRequest(
|
|
||||||
model="test-model",
|
|
||||||
input="Hello, how are you?",
|
|
||||||
encoding_format="float",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def list_embedding_request():
|
|
||||||
"""Create an embedding request with list input."""
|
|
||||||
return EmbeddingRequest(
|
|
||||||
model="test-model",
|
|
||||||
input=["Hello, how are you?", "I am fine, thank you!"],
|
|
||||||
encoding_format="float",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def multimodal_embedding_request():
|
|
||||||
"""Create a multimodal embedding request."""
|
|
||||||
return EmbeddingRequest(
|
|
||||||
model="test-model",
|
|
||||||
input=[
|
|
||||||
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
|
|
||||||
MultimodalEmbeddingInput(text="World", image=None),
|
|
||||||
],
|
|
||||||
encoding_format="float",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def token_ids_embedding_request():
|
|
||||||
"""Create an embedding request with token IDs."""
|
|
||||||
return EmbeddingRequest(
|
|
||||||
model="test-model",
|
|
||||||
input=[1, 2, 3, 4, 5],
|
|
||||||
encoding_format="float",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIServingEmbeddingConversion:
|
|
||||||
"""Test request conversion methods."""
|
|
||||||
|
|
||||||
def test_convert_single_string_request(
|
|
||||||
self, serving_embedding, basic_embedding_request
|
|
||||||
):
|
|
||||||
"""Test converting single string request to internal format."""
|
"""Test converting single string request to internal format."""
|
||||||
adapted_request, processed_request = (
|
adapted_request, processed_request = (
|
||||||
serving_embedding._convert_to_internal_request(
|
self.serving_embedding._convert_to_internal_request(
|
||||||
[basic_embedding_request], ["test-id"]
|
[self.basic_req], ["test-id"]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(adapted_request, EmbeddingReqInput)
|
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||||
assert adapted_request.text == "Hello, how are you?"
|
self.assertEqual(adapted_request.text, "Hello, how are you?")
|
||||||
assert adapted_request.rid == "test-id"
|
self.assertEqual(adapted_request.rid, "test-id")
|
||||||
assert processed_request == basic_embedding_request
|
self.assertEqual(processed_request, self.basic_req)
|
||||||
|
|
||||||
def test_convert_list_string_request(
|
def test_convert_list_string_request(self):
|
||||||
self, serving_embedding, list_embedding_request
|
|
||||||
):
|
|
||||||
"""Test converting list of strings request to internal format."""
|
"""Test converting list of strings request to internal format."""
|
||||||
adapted_request, processed_request = (
|
adapted_request, processed_request = (
|
||||||
serving_embedding._convert_to_internal_request(
|
self.serving_embedding._convert_to_internal_request(
|
||||||
[list_embedding_request], ["test-id"]
|
[self.list_req], ["test-id"]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(adapted_request, EmbeddingReqInput)
|
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||||
assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"]
|
self.assertEqual(
|
||||||
assert adapted_request.rid == "test-id"
|
adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
|
||||||
assert processed_request == list_embedding_request
|
)
|
||||||
|
self.assertEqual(adapted_request.rid, "test-id")
|
||||||
|
self.assertEqual(processed_request, self.list_req)
|
||||||
|
|
||||||
def test_convert_token_ids_request(
|
def test_convert_token_ids_request(self):
|
||||||
self, serving_embedding, token_ids_embedding_request
|
|
||||||
):
|
|
||||||
"""Test converting token IDs request to internal format."""
|
"""Test converting token IDs request to internal format."""
|
||||||
adapted_request, processed_request = (
|
adapted_request, processed_request = (
|
||||||
serving_embedding._convert_to_internal_request(
|
self.serving_embedding._convert_to_internal_request(
|
||||||
[token_ids_embedding_request], ["test-id"]
|
[self.token_ids_req], ["test-id"]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(adapted_request, EmbeddingReqInput)
|
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||||
assert adapted_request.input_ids == [1, 2, 3, 4, 5]
|
self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
|
||||||
assert adapted_request.rid == "test-id"
|
self.assertEqual(adapted_request.rid, "test-id")
|
||||||
assert processed_request == token_ids_embedding_request
|
self.assertEqual(processed_request, self.token_ids_req)
|
||||||
|
|
||||||
def test_convert_multimodal_request(
|
def test_convert_multimodal_request(self):
|
||||||
self, serving_embedding, multimodal_embedding_request
|
|
||||||
):
|
|
||||||
"""Test converting multimodal request to internal format."""
|
"""Test converting multimodal request to internal format."""
|
||||||
adapted_request, processed_request = (
|
adapted_request, processed_request = (
|
||||||
serving_embedding._convert_to_internal_request(
|
self.serving_embedding._convert_to_internal_request(
|
||||||
[multimodal_embedding_request], ["test-id"]
|
[self.multimodal_req], ["test-id"]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(adapted_request, EmbeddingReqInput)
|
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||||
# Should extract text and images separately
|
# Should extract text and images separately
|
||||||
assert len(adapted_request.text) == 2
|
self.assertEqual(len(adapted_request.text), 2)
|
||||||
assert "Hello" in adapted_request.text
|
self.assertIn("Hello", adapted_request.text)
|
||||||
assert "World" in adapted_request.text
|
self.assertIn("World", adapted_request.text)
|
||||||
assert adapted_request.image_data[0] == "base64_image_data"
|
self.assertEqual(adapted_request.image_data[0], "base64_image_data")
|
||||||
assert adapted_request.image_data[1] is None
|
self.assertIsNone(adapted_request.image_data[1])
|
||||||
assert adapted_request.rid == "test-id"
|
self.assertEqual(adapted_request.rid, "test-id")
|
||||||
|
|
||||||
|
def test_build_single_embedding_response(self):
|
||||||
class TestEmbeddingResponseBuilding:
|
|
||||||
"""Test response building methods."""
|
|
||||||
|
|
||||||
def test_build_single_embedding_response(self, serving_embedding):
|
|
||||||
"""Test building response for single embedding."""
|
"""Test building response for single embedding."""
|
||||||
ret_data = [
|
ret_data = [
|
||||||
{
|
{
|
||||||
@@ -201,19 +158,21 @@ class TestEmbeddingResponseBuilding:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
response = serving_embedding._build_embedding_response(ret_data, "test-model")
|
response = self.serving_embedding._build_embedding_response(
|
||||||
|
ret_data, "test-model"
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(response, EmbeddingResponse)
|
self.assertIsInstance(response, EmbeddingResponse)
|
||||||
assert response.model == "test-model"
|
self.assertEqual(response.model, "test-model")
|
||||||
assert len(response.data) == 1
|
self.assertEqual(len(response.data), 1)
|
||||||
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
|
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||||
assert response.data[0].index == 0
|
self.assertEqual(response.data[0].index, 0)
|
||||||
assert response.data[0].object == "embedding"
|
self.assertEqual(response.data[0].object, "embedding")
|
||||||
assert response.usage.prompt_tokens == 5
|
self.assertEqual(response.usage.prompt_tokens, 5)
|
||||||
assert response.usage.total_tokens == 5
|
self.assertEqual(response.usage.total_tokens, 5)
|
||||||
assert response.usage.completion_tokens == 0
|
self.assertEqual(response.usage.completion_tokens, 0)
|
||||||
|
|
||||||
def test_build_multiple_embedding_response(self, serving_embedding):
|
def test_build_multiple_embedding_response(self):
|
||||||
"""Test building response for multiple embeddings."""
|
"""Test building response for multiple embeddings."""
|
||||||
ret_data = [
|
ret_data = [
|
||||||
{
|
{
|
||||||
@@ -226,25 +185,20 @@ class TestEmbeddingResponseBuilding:
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
response = serving_embedding._build_embedding_response(ret_data, "test-model")
|
response = self.serving_embedding._build_embedding_response(
|
||||||
|
ret_data, "test-model"
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(response, EmbeddingResponse)
|
self.assertIsInstance(response, EmbeddingResponse)
|
||||||
assert len(response.data) == 2
|
self.assertEqual(len(response.data), 2)
|
||||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
|
||||||
assert response.data[0].index == 0
|
self.assertEqual(response.data[0].index, 0)
|
||||||
assert response.data[1].embedding == [0.4, 0.5, 0.6]
|
self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
|
||||||
assert response.data[1].index == 1
|
self.assertEqual(response.data[1].index, 1)
|
||||||
assert response.usage.prompt_tokens == 7 # 3 + 4
|
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
|
||||||
assert response.usage.total_tokens == 7
|
self.assertEqual(response.usage.total_tokens, 7)
|
||||||
|
|
||||||
|
async def test_handle_request_success(self):
|
||||||
@pytest.mark.asyncio
|
|
||||||
class TestOpenAIServingEmbeddingAsyncMethods:
|
|
||||||
"""Test async methods of OpenAIServingEmbedding."""
|
|
||||||
|
|
||||||
async def test_handle_request_success(
|
|
||||||
self, serving_embedding, basic_embedding_request, mock_request
|
|
||||||
):
|
|
||||||
"""Test successful embedding request handling."""
|
"""Test successful embedding request handling."""
|
||||||
|
|
||||||
# Mock the generate_request to return expected data
|
# Mock the generate_request to return expected data
|
||||||
@@ -254,32 +208,30 @@ class TestOpenAIServingEmbeddingAsyncMethods:
|
|||||||
"meta_info": {"prompt_tokens": 5},
|
"meta_info": {"prompt_tokens": 5},
|
||||||
}
|
}
|
||||||
|
|
||||||
serving_embedding.tokenizer_manager.generate_request = Mock(
|
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||||
return_value=mock_generate()
|
return_value=mock_generate()
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await serving_embedding.handle_request(
|
response = await self.serving_embedding.handle_request(
|
||||||
basic_embedding_request, mock_request
|
self.basic_req, self.request
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(response, EmbeddingResponse)
|
self.assertIsInstance(response, EmbeddingResponse)
|
||||||
assert len(response.data) == 1
|
self.assertEqual(len(response.data), 1)
|
||||||
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
|
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||||
|
|
||||||
async def test_handle_request_validation_error(
|
async def test_handle_request_validation_error(self):
|
||||||
self, serving_embedding, mock_request
|
|
||||||
):
|
|
||||||
"""Test handling request with validation error."""
|
"""Test handling request with validation error."""
|
||||||
invalid_request = EmbeddingRequest(model="test-model", input="")
|
invalid_request = EmbeddingRequest(model="test-model", input="")
|
||||||
|
|
||||||
response = await serving_embedding.handle_request(invalid_request, mock_request)
|
response = await self.serving_embedding.handle_request(
|
||||||
|
invalid_request, self.request
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(response, ORJSONResponse)
|
self.assertIsInstance(response, ORJSONResponse)
|
||||||
assert response.status_code == 400
|
self.assertEqual(response.status_code, 400)
|
||||||
|
|
||||||
async def test_handle_request_generation_error(
|
async def test_handle_request_generation_error(self):
|
||||||
self, serving_embedding, basic_embedding_request, mock_request
|
|
||||||
):
|
|
||||||
"""Test handling request with generation error."""
|
"""Test handling request with generation error."""
|
||||||
|
|
||||||
# Mock generate_request to raise an error
|
# Mock generate_request to raise an error
|
||||||
@@ -287,30 +239,32 @@ class TestOpenAIServingEmbeddingAsyncMethods:
|
|||||||
raise ValueError("Generation failed")
|
raise ValueError("Generation failed")
|
||||||
yield # This won't be reached but needed for async generator
|
yield # This won't be reached but needed for async generator
|
||||||
|
|
||||||
serving_embedding.tokenizer_manager.generate_request = Mock(
|
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||||
return_value=mock_generate_error()
|
return_value=mock_generate_error()
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await serving_embedding.handle_request(
|
response = await self.serving_embedding.handle_request(
|
||||||
basic_embedding_request, mock_request
|
self.basic_req, self.request
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(response, ORJSONResponse)
|
self.assertIsInstance(response, ORJSONResponse)
|
||||||
assert response.status_code == 400
|
self.assertEqual(response.status_code, 400)
|
||||||
|
|
||||||
async def test_handle_request_internal_error(
|
async def test_handle_request_internal_error(self):
|
||||||
self, serving_embedding, basic_embedding_request, mock_request
|
|
||||||
):
|
|
||||||
"""Test handling request with internal server error."""
|
"""Test handling request with internal server error."""
|
||||||
# Mock _convert_to_internal_request to raise an exception
|
# Mock _convert_to_internal_request to raise an exception
|
||||||
with patch.object(
|
with patch.object(
|
||||||
serving_embedding,
|
self.serving_embedding,
|
||||||
"_convert_to_internal_request",
|
"_convert_to_internal_request",
|
||||||
side_effect=Exception("Internal error"),
|
side_effect=Exception("Internal error"),
|
||||||
):
|
):
|
||||||
response = await serving_embedding.handle_request(
|
response = await self.serving_embedding.handle_request(
|
||||||
basic_embedding_request, mock_request
|
self.basic_req, self.request
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(response, ORJSONResponse)
|
self.assertIsInstance(response, ORJSONResponse)
|
||||||
assert response.status_code == 500
|
self.assertEqual(response.status_code, 500)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(verbosity=2)
|
||||||
|
|||||||
@@ -62,6 +62,11 @@ suites = {
|
|||||||
TestFile("test_openai_adapter.py", 1),
|
TestFile("test_openai_adapter.py", 1),
|
||||||
TestFile("test_openai_function_calling.py", 60),
|
TestFile("test_openai_function_calling.py", 60),
|
||||||
TestFile("test_openai_server.py", 149),
|
TestFile("test_openai_server.py", 149),
|
||||||
|
TestFile("openai/test_server.py", 120),
|
||||||
|
TestFile("openai/test_protocol.py", 60),
|
||||||
|
TestFile("openai/test_serving_chat.py", 120),
|
||||||
|
TestFile("openai/test_serving_completions.py", 120),
|
||||||
|
TestFile("openai/test_serving_embedding.py", 120),
|
||||||
TestFile("test_openai_server_hidden_states.py", 240),
|
TestFile("test_openai_server_hidden_states.py", 240),
|
||||||
TestFile("test_penalty.py", 41),
|
TestFile("test_penalty.py", 41),
|
||||||
TestFile("test_page_size.py", 60),
|
TestFile("test_page_size.py", 60),
|
||||||
|
|||||||
Reference in New Issue
Block a user