feat(oai refactor): Replace openai_api with entrypoints/openai (#7351)

Co-authored-by: Jin Pan <jpan236@wisc.edu>
Authored by Chang Su on 2025-06-21 13:21:06 -07:00; committed by GitHub.
parent 02bf31ef29
commit 72676cd6c0
43 changed files with 674 additions and 4555 deletions
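
At a glance, the refactor moves the OpenAI-compatible layer from sglang.srt.openai_api to sglang.srt.entrypoints.openai, and the Jinja template helpers to sglang.srt.jinja_template_utils. A condensed sketch of the import migration, with symbol names taken from the hunks below:

# Before this commit:
from sglang.srt.openai_api.protocol import ChatCompletionRequest, Function, Tool
from sglang.srt.openai_api.utils import detect_template_content_format

# After this commit:
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest, Function, Tool
from sglang.srt.jinja_template_utils import detect_jinja_template_content_format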

View File

@@ -1,87 +0,0 @@
# sglang/test/srt/openai/conftest.py
import os
import socket
import subprocess
import sys
import tempfile
import time
from contextlib import closing
from typing import Generator
import pytest
import requests
from sglang.srt.utils import kill_process_tree # reuse SGLang helper
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
def _pick_free_port() -> int:
with closing(socket.socket()) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> None:
start = time.perf_counter()
while time.perf_counter() - start < timeout:
if proc.poll() is not None: # crashed
raise RuntimeError("api_server terminated prematurely")
try:
if requests.get(f"{base}/health", timeout=1).status_code == 200:
return
except requests.RequestException:
pass
time.sleep(0.4)
raise RuntimeError("api_server readiness probe timed out")
def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
"""Spawn the draft OpenAI-compatible server and wait until it's ready."""
port = _pick_free_port()
cmd = [
sys.executable,
"-m",
SERVER_MODULE,
"--model-path",
model,
"--host",
"127.0.0.1",
"--port",
str(port),
*map(str, kw.get("args", [])),
]
env = {**os.environ, **kw.get("env", {})}
# Write logs to a temp file so the child never blocks on a full pipe.
log_file = tempfile.NamedTemporaryFile("w+", delete=False)
proc = subprocess.Popen(
cmd,
env=env,
stdout=log_file,
stderr=subprocess.STDOUT,
text=True,
)
base = f"http://127.0.0.1:{port}"
try:
_wait_until_healthy(proc, base, STARTUP_TIMEOUT)
except Exception as e:
proc.terminate()
proc.wait(5)
log_file.seek(0)
print("\n--- api_server log ---\n", log_file.read(), file=sys.stderr)
raise e
return proc, base, log_file
@pytest.fixture(scope="session")
def openai_server() -> Generator[str, None, None]:
"""PyTest fixture that provides the server's base URL and cleans up."""
proc, base, log_file = launch_openai_server()
yield base
kill_process_tree(proc.pid)
log_file.close()
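
Outside the fixture, launch_openai_server also threads extra CLI flags through its keyword arguments; a minimal direct-use sketch (the --log-level flag is illustrative, not part of this diff):

# Hypothetical direct use of the helper above.
proc, base, log_file = launch_openai_server(args=["--log-level", "warning"])
try:
    assert requests.get(f"{base}/v1/models", timeout=5).status_code == 200
finally:
    kill_process_tree(proc.pid)  # same cleanup path as the fixture
    log_file.close()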

View File

@@ -67,29 +67,6 @@ from sglang.srt.entrypoints.openai.protocol import (
class TestModelCard(unittest.TestCase):
"""Test ModelCard protocol model"""
def test_basic_model_card_creation(self):
"""Test basic model card creation with required fields"""
card = ModelCard(id="test-model")
self.assertEqual(card.id, "test-model")
self.assertEqual(card.object, "model")
self.assertEqual(card.owned_by, "sglang")
self.assertIsInstance(card.created, int)
self.assertIsNone(card.root)
self.assertIsNone(card.max_model_len)
def test_model_card_with_optional_fields(self):
"""Test model card with optional fields"""
card = ModelCard(
id="test-model",
root="/path/to/model",
max_model_len=2048,
created=1234567890,
)
self.assertEqual(card.id, "test-model")
self.assertEqual(card.root, "/path/to/model")
self.assertEqual(card.max_model_len, 2048)
self.assertEqual(card.created, 1234567890)
def test_model_card_serialization(self):
"""Test model card JSON serialization"""
card = ModelCard(id="test-model", max_model_len=4096)
@@ -120,53 +97,6 @@ class TestModelList(unittest.TestCase):
self.assertEqual(model_list.data[1].id, "model-2")
class TestErrorResponse(unittest.TestCase):
"""Test ErrorResponse protocol model"""
def test_basic_error_response(self):
"""Test basic error response creation"""
error = ErrorResponse(
message="Invalid request", type="BadRequestError", code=400
)
self.assertEqual(error.object, "error")
self.assertEqual(error.message, "Invalid request")
self.assertEqual(error.type, "BadRequestError")
self.assertEqual(error.code, 400)
self.assertIsNone(error.param)
def test_error_response_with_param(self):
"""Test error response with parameter"""
error = ErrorResponse(
message="Invalid temperature",
type="ValidationError",
code=422,
param="temperature",
)
self.assertEqual(error.param, "temperature")
class TestUsageInfo(unittest.TestCase):
"""Test UsageInfo protocol model"""
def test_basic_usage_info(self):
"""Test basic usage info creation"""
usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
self.assertEqual(usage.prompt_tokens, 10)
self.assertEqual(usage.completion_tokens, 20)
self.assertEqual(usage.total_tokens, 30)
self.assertIsNone(usage.prompt_tokens_details)
def test_usage_info_with_cache_details(self):
"""Test usage info with cache details"""
usage = UsageInfo(
prompt_tokens=10,
completion_tokens=20,
total_tokens=30,
prompt_tokens_details={"cached_tokens": 5},
)
self.assertEqual(usage.prompt_tokens_details, {"cached_tokens": 5})
class TestCompletionRequest(unittest.TestCase):
"""Test CompletionRequest protocol model"""
@@ -181,30 +111,6 @@ class TestCompletionRequest(unittest.TestCase):
self.assertFalse(request.stream) # default
self.assertFalse(request.echo) # default
def test_completion_request_with_options(self):
"""Test completion request with various options"""
request = CompletionRequest(
model="test-model",
prompt=["Hello", "world"],
max_tokens=100,
temperature=0.7,
top_p=0.9,
n=2,
stream=True,
echo=True,
stop=[".", "!"],
logprobs=5,
)
self.assertEqual(request.prompt, ["Hello", "world"])
self.assertEqual(request.max_tokens, 100)
self.assertEqual(request.temperature, 0.7)
self.assertEqual(request.top_p, 0.9)
self.assertEqual(request.n, 2)
self.assertTrue(request.stream)
self.assertTrue(request.echo)
self.assertEqual(request.stop, [".", "!"])
self.assertEqual(request.logprobs, 5)
def test_completion_request_sglang_extensions(self):
"""Test completion request with SGLang-specific extensions"""
request = CompletionRequest(
@@ -233,26 +139,6 @@ class TestCompletionRequest(unittest.TestCase):
CompletionRequest(model="test-model") # missing prompt
class TestCompletionResponse(unittest.TestCase):
"""Test CompletionResponse protocol model"""
def test_basic_completion_response(self):
"""Test basic completion response"""
choice = CompletionResponseChoice(
index=0, text="Hello world!", finish_reason="stop"
)
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
response = CompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
self.assertEqual(response.id, "test-id")
self.assertEqual(response.object, "text_completion")
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.choices), 1)
self.assertEqual(response.choices[0].text, "Hello world!")
self.assertEqual(response.usage.total_tokens, 5)
class TestChatCompletionRequest(unittest.TestCase):
"""Test ChatCompletionRequest protocol model"""
@@ -268,48 +154,6 @@ class TestChatCompletionRequest(unittest.TestCase):
self.assertFalse(request.stream) # default
self.assertEqual(request.tool_choice, "none") # default when no tools
def test_chat_completion_with_multimodal_content(self):
"""Test chat completion with multimodal content"""
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
},
],
}
]
request = ChatCompletionRequest(model="test-model", messages=messages)
self.assertEqual(len(request.messages[0].content), 2)
self.assertEqual(request.messages[0].content[0].type, "text")
self.assertEqual(request.messages[0].content[1].type, "image_url")
def test_chat_completion_with_tools(self):
"""Test chat completion with tools"""
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather information",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
]
request = ChatCompletionRequest(
model="test-model", messages=messages, tools=tools
)
self.assertEqual(len(request.tools), 1)
self.assertEqual(request.tools[0].function.name, "get_weather")
self.assertEqual(request.tool_choice, "auto") # default when tools present
def test_chat_completion_tool_choice_validation(self):
"""Test tool choice validation logic"""
messages = [{"role": "user", "content": "Hello"}]
@@ -349,289 +193,6 @@ class TestChatCompletionRequest(unittest.TestCase):
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
class TestChatCompletionResponse(unittest.TestCase):
"""Test ChatCompletionResponse protocol model"""
def test_basic_chat_completion_response(self):
"""Test basic chat completion response"""
message = ChatMessage(role="assistant", content="Hello there!")
choice = ChatCompletionResponseChoice(
index=0, message=message, finish_reason="stop"
)
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
self.assertEqual(response.id, "test-id")
self.assertEqual(response.object, "chat.completion")
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.choices), 1)
self.assertEqual(response.choices[0].message.content, "Hello there!")
def test_chat_completion_response_with_tool_calls(self):
"""Test chat completion response with tool calls"""
tool_call = ToolCall(
id="call_123",
function=FunctionResponse(
name="get_weather", arguments='{"location": "San Francisco"}'
),
)
message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call])
choice = ChatCompletionResponseChoice(
index=0, message=message, finish_reason="tool_calls"
)
usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15)
response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
self.assertEqual(
response.choices[0].message.tool_calls[0].function.name, "get_weather"
)
self.assertEqual(response.choices[0].finish_reason, "tool_calls")
class TestEmbeddingRequest(unittest.TestCase):
"""Test EmbeddingRequest protocol model"""
def test_basic_embedding_request(self):
"""Test basic embedding request"""
request = EmbeddingRequest(model="test-model", input="Hello world")
self.assertEqual(request.model, "test-model")
self.assertEqual(request.input, "Hello world")
self.assertEqual(request.encoding_format, "float") # default
self.assertIsNone(request.dimensions) # default
def test_embedding_request_with_list_input(self):
"""Test embedding request with list input"""
request = EmbeddingRequest(
model="test-model", input=["Hello", "world"], dimensions=512
)
self.assertEqual(request.input, ["Hello", "world"])
self.assertEqual(request.dimensions, 512)
def test_multimodal_embedding_request(self):
"""Test multimodal embedding request"""
multimodal_input = [
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
MultimodalEmbeddingInput(text="World", image=None),
]
request = EmbeddingRequest(model="test-model", input=multimodal_input)
self.assertEqual(len(request.input), 2)
self.assertEqual(request.input[0].text, "Hello")
self.assertEqual(request.input[0].image, "base64_image_data")
self.assertEqual(request.input[1].text, "World")
self.assertIsNone(request.input[1].image)
class TestEmbeddingResponse(unittest.TestCase):
"""Test EmbeddingResponse protocol model"""
def test_basic_embedding_response(self):
"""Test basic embedding response"""
embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)
usage = UsageInfo(prompt_tokens=3, total_tokens=3)
response = EmbeddingResponse(
data=[embedding_obj], model="test-model", usage=usage
)
self.assertEqual(response.object, "list")
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.usage.prompt_tokens, 3)
class TestScoringRequest(unittest.TestCase):
"""Test ScoringRequest protocol model"""
def test_basic_scoring_request(self):
"""Test basic scoring request"""
request = ScoringRequest(
model="test-model", query="Hello", items=["World", "Earth"]
)
self.assertEqual(request.model, "test-model")
self.assertEqual(request.query, "Hello")
self.assertEqual(request.items, ["World", "Earth"])
self.assertFalse(request.apply_softmax) # default
self.assertFalse(request.item_first) # default
def test_scoring_request_with_token_ids(self):
"""Test scoring request with token IDs"""
request = ScoringRequest(
model="test-model",
query=[1, 2, 3],
items=[[4, 5], [6, 7]],
label_token_ids=[8, 9],
apply_softmax=True,
item_first=True,
)
self.assertEqual(request.query, [1, 2, 3])
self.assertEqual(request.items, [[4, 5], [6, 7]])
self.assertEqual(request.label_token_ids, [8, 9])
self.assertTrue(request.apply_softmax)
self.assertTrue(request.item_first)
class TestScoringResponse(unittest.TestCase):
"""Test ScoringResponse protocol model"""
def test_basic_scoring_response(self):
"""Test basic scoring response"""
response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
self.assertEqual(response.object, "scoring")
self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]])
self.assertEqual(response.model, "test-model")
self.assertIsNone(response.usage) # default
class TestFileOperations(unittest.TestCase):
"""Test file operation protocol models"""
def test_file_request(self):
"""Test file request model"""
file_data = b"test file content"
request = FileRequest(file=file_data, purpose="batch")
self.assertEqual(request.file, file_data)
self.assertEqual(request.purpose, "batch")
def test_file_response(self):
"""Test file response model"""
response = FileResponse(
id="file-123",
bytes=1024,
created_at=1234567890,
filename="test.jsonl",
purpose="batch",
)
self.assertEqual(response.id, "file-123")
self.assertEqual(response.object, "file")
self.assertEqual(response.bytes, 1024)
self.assertEqual(response.filename, "test.jsonl")
def test_file_delete_response(self):
"""Test file delete response model"""
response = FileDeleteResponse(id="file-123", deleted=True)
self.assertEqual(response.id, "file-123")
self.assertEqual(response.object, "file")
self.assertTrue(response.deleted)
class TestBatchOperations(unittest.TestCase):
"""Test batch operation protocol models"""
def test_batch_request(self):
"""Test batch request model"""
request = BatchRequest(
input_file_id="file-123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"custom": "value"},
)
self.assertEqual(request.input_file_id, "file-123")
self.assertEqual(request.endpoint, "/v1/chat/completions")
self.assertEqual(request.completion_window, "24h")
self.assertEqual(request.metadata, {"custom": "value"})
def test_batch_response(self):
"""Test batch response model"""
response = BatchResponse(
id="batch-123",
endpoint="/v1/chat/completions",
input_file_id="file-123",
completion_window="24h",
created_at=1234567890,
)
self.assertEqual(response.id, "batch-123")
self.assertEqual(response.object, "batch")
self.assertEqual(response.status, "validating") # default
self.assertEqual(response.endpoint, "/v1/chat/completions")
class TestResponseFormats(unittest.TestCase):
"""Test response format protocol models"""
def test_basic_response_format(self):
"""Test basic response format"""
format_obj = ResponseFormat(type="json_object")
self.assertEqual(format_obj.type, "json_object")
self.assertIsNone(format_obj.json_schema)
def test_json_schema_response_format(self):
"""Test JSON schema response format"""
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
json_schema = JsonSchemaResponseFormat(
name="person_schema", description="Person schema", schema=schema
)
format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
self.assertEqual(format_obj.type, "json_schema")
self.assertEqual(format_obj.json_schema.name, "person_schema")
self.assertEqual(format_obj.json_schema.schema_, schema)
def test_structural_tag_response_format(self):
"""Test structural tag response format"""
structures = [
{
"begin": "<thinking>",
"schema_": {"type": "string"},
"end": "</thinking>",
}
]
format_obj = StructuralTagResponseFormat(
type="structural_tag", structures=structures, triggers=["think"]
)
self.assertEqual(format_obj.type, "structural_tag")
self.assertEqual(len(format_obj.structures), 1)
self.assertEqual(format_obj.triggers, ["think"])
class TestLogProbs(unittest.TestCase):
"""Test LogProbs protocol models"""
def test_basic_logprobs(self):
"""Test basic LogProbs model"""
logprobs = LogProbs(
text_offset=[0, 5, 11],
token_logprobs=[-0.1, -0.2, -0.3],
tokens=["Hello", " ", "world"],
top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
)
self.assertEqual(len(logprobs.tokens), 3)
self.assertEqual(logprobs.tokens, ["Hello", " ", "world"])
self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3])
def test_choice_logprobs(self):
"""Test ChoiceLogprobs model"""
token_logprob = ChatCompletionTokenLogprob(
token="Hello",
bytes=[72, 101, 108, 108, 111],
logprob=-0.1,
top_logprobs=[
TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1)
],
)
choice_logprobs = ChoiceLogprobs(content=[token_logprob])
self.assertEqual(len(choice_logprobs.content), 1)
self.assertEqual(choice_logprobs.content[0].token, "Hello")
class TestStreamingModels(unittest.TestCase):
"""Test streaming response models"""
def test_stream_options(self):
"""Test StreamOptions model"""
options = StreamOptions(include_usage=True)
self.assertTrue(options.include_usage)
def test_chat_completion_stream_response(self):
"""Test ChatCompletionStreamResponse model"""
delta = DeltaMessage(role="assistant", content="Hello")
choice = ChatCompletionResponseStreamChoice(index=0, delta=delta)
response = ChatCompletionStreamResponse(
id="test-id", model="test-model", choices=[choice]
)
self.assertEqual(response.object, "chat.completion.chunk")
self.assertEqual(response.choices[0].delta.content, "Hello")
class TestModelSerialization(unittest.TestCase):
"""Test model serialization with hidden states"""
@@ -680,11 +241,6 @@ class TestModelSerialization(unittest.TestCase):
class TestValidationEdgeCases(unittest.TestCase):
"""Test edge cases and validation scenarios"""
def test_empty_messages_validation(self):
"""Test validation with empty messages"""
with self.assertRaises(ValidationError):
ChatCompletionRequest(model="test-model", messages=[])
def test_invalid_tool_choice_type(self):
"""Test invalid tool choice type"""
messages = [{"role": "user", "content": "Hello"}]
@@ -698,13 +254,6 @@ class TestValidationEdgeCases(unittest.TestCase):
with self.assertRaises(ValidationError):
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
def test_invalid_temperature_range(self):
"""Test invalid temperature values"""
# Note: The current protocol doesn't enforce temperature range,
# but this test documents expected behavior
request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
self.assertEqual(request.temperature, 5.0) # Currently allowed
def test_model_serialization_roundtrip(self):
"""Test that models can be serialized and deserialized"""
original_request = ChatCompletionRequest(

View File

@@ -1,52 +0,0 @@
# sglang/test/srt/openai/test_server.py
import requests
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID
def test_health(openai_server: str):
r = requests.get(f"{openai_server}/health")
assert r.status_code == 200
# FastAPI returns an empty body → r.text == ""
assert r.text == ""
def test_models_endpoint(openai_server: str):
r = requests.get(f"{openai_server}/v1/models")
assert r.status_code == 200, r.text
payload = r.json()
# Basic contract
assert "data" in payload and isinstance(payload["data"], list) and payload["data"]
# Validate fields of the first model card
first = payload["data"][0]
for key in ("id", "root", "max_model_len"):
assert key in first, f"missing {key} in {first}"
# max_model_len must be positive
assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0
# The server should report the same model id we launched it with
ids = {m["id"] for m in payload["data"]}
assert MODEL_ID in ids
def test_get_model_info(openai_server: str):
r = requests.get(f"{openai_server}/get_model_info")
assert r.status_code == 200, r.text
info = r.json()
expected_keys = {"model_path", "tokenizer_path", "is_generation"}
assert expected_keys.issubset(info.keys())
# model_path must end with the one we passed on the CLI
assert info["model_path"].endswith(MODEL_ID)
# is_generation is documented as a boolean
assert isinstance(info["is_generation"], bool)
def test_unknown_route_returns_404(openai_server: str):
r = requests.get(f"{openai_server}/definitely-not-a-real-route")
assert r.status_code == 404

View File

@@ -57,11 +57,21 @@ class _MockTokenizerManager:
self.create_abort_task = Mock()
class _MockTemplateManager:
"""Minimal mock for TemplateManager."""
def __init__(self):
self.chat_template_name: Optional[str] = "llama-3"
self.jinja_template_content_format: Optional[str] = None
self.completion_template_name: Optional[str] = None
class ServingChatTestCase(unittest.TestCase):
# ------------- common fixtures -------------
def setUp(self):
self.tm = _MockTokenizerManager()
self.chat = OpenAIServingChat(self.tm)
self.template_manager = _MockTemplateManager()
self.chat = OpenAIServingChat(self.tm, self.template_manager)
# frequently reused requests
self.basic_req = ChatCompletionRequest(
@@ -109,96 +119,6 @@ class ServingChatTestCase(unittest.TestCase):
self.assertFalse(adapted.stream)
self.assertEqual(processed, self.basic_req)
# # ------------- tool-call branch -------------
# def test_tool_call_request_conversion(self):
# req = ChatCompletionRequest(
# model="x",
# messages=[{"role": "user", "content": "Weather?"}],
# tools=[
# {
# "type": "function",
# "function": {
# "name": "get_weather",
# "parameters": {"type": "object", "properties": {}},
# },
# }
# ],
# tool_choice="auto",
# )
# with patch.object(
# self.chat,
# "_process_messages",
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
# ):
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
# self.assertEqual(adapted.rid, "rid")
# def test_tool_choice_none(self):
# req = ChatCompletionRequest(
# model="x",
# messages=[{"role": "user", "content": "Hi"}],
# tools=[{"type": "function", "function": {"name": "noop"}}],
# tool_choice="none",
# )
# with patch.object(
# self.chat,
# "_process_messages",
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
# ):
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
# self.assertEqual(adapted.rid, "rid")
# ------------- multimodal branch -------------
def test_multimodal_request_with_images(self):
self.tm.model_config.is_multimodal = True
req = ChatCompletionRequest(
model="x",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in the image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,"},
},
],
}
],
)
with patch.object(
self.chat,
"_apply_jinja_template",
return_value=("prompt", [1, 2], ["img"], None, [], []),
), patch.object(
self.chat,
"_apply_conversation_template",
return_value=("prompt", ["img"], None, [], []),
):
out = self.chat._process_messages(req, True)
_, _, image_data, *_ = out
self.assertEqual(image_data, ["img"])
# ------------- template handling -------------
def test_jinja_template_processing(self):
req = ChatCompletionRequest(
model="x", messages=[{"role": "user", "content": "Hello"}]
)
self.tm.chat_template_name = None
self.tm.tokenizer.chat_template = "<jinja>"
with patch.object(
self.chat,
"_apply_jinja_template",
return_value=("processed", [1], None, None, [], ["</s>"]),
), patch("builtins.hasattr", return_value=True):
prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
self.assertEqual(prompt, "processed")
self.assertEqual(prompt_ids, [1])
# ------------- sampling-params -------------
def test_sampling_param_build(self):
req = ChatCompletionRequest(

View File

@@ -5,6 +5,7 @@ Run with:
"""
import unittest
from typing import Optional
from unittest.mock import AsyncMock, Mock, patch
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
@@ -12,6 +13,17 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.managers.tokenizer_manager import TokenizerManager
class _MockTemplateManager:
"""Minimal mock for TemplateManager."""
def __init__(self):
self.chat_template_name: Optional[str] = None
self.jinja_template_content_format: Optional[str] = None
self.completion_template_name: Optional[str] = (
None # Set to None to avoid template processing
)
class ServingCompletionTestCase(unittest.TestCase):
"""Bundle all prompt/echo tests in one TestCase."""
@@ -31,7 +43,8 @@ class ServingCompletionTestCase(unittest.TestCase):
tm.generate_request = AsyncMock()
tm.create_abort_task = Mock()
self.sc = OpenAIServingCompletion(tm)
self.template_manager = _MockTemplateManager()
self.sc = OpenAIServingCompletion(tm, self.template_manager)
# ---------- prompt-handling ----------
def test_single_string_prompt(self):
@@ -44,20 +57,6 @@ class ServingCompletionTestCase(unittest.TestCase):
internal, _ = self.sc._convert_to_internal_request(req)
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
def test_completion_template_handling(self):
req = CompletionRequest(
model="x", prompt="def f():", suffix="return 1", max_tokens=100
)
with patch(
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
return_value=True,
), patch(
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
return_value="processed_prompt",
):
internal, _ = self.sc._convert_to_internal_request(req)
self.assertEqual(internal.text, "processed_prompt")
# ---------- echo-handling ----------
def test_echo_with_string_prompt_streaming(self):
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)

View File

@@ -5,25 +5,16 @@ These tests ensure that the embedding serving implementation maintains compatibility
with the original adapter.py functionality and follows OpenAI API specifications.
"""
import asyncio
import json
import time
import unittest
import uuid
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import Mock
from fastapi import Request
from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError
from sglang.srt.entrypoints.openai.protocol import (
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
MultimodalEmbeddingInput,
UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput
@@ -58,11 +49,22 @@ class _MockTokenizerManager:
self.generate_request = Mock(return_value=mock_generate_embedding())
# Mock TemplateManager for embedding tests
class _MockTemplateManager:
def __init__(self):
self.chat_template_name = None # None for embeddings usually
self.jinja_template_content_format = None
self.completion_template_name = None
class ServingEmbeddingTestCase(unittest.TestCase):
def setUp(self):
"""Set up test fixtures."""
self.tokenizer_manager = _MockTokenizerManager()
self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
self.template_manager = _MockTemplateManager()
self.serving_embedding = OpenAIServingEmbedding(
self.tokenizer_manager, self.template_manager
)
self.request = Mock(spec=Request)
self.request.headers = {}
@@ -141,132 +143,6 @@ class ServingEmbeddingTestCase(unittest.TestCase):
self.assertIsNone(adapted_request.image_data[1])
# self.assertEqual(adapted_request.rid, "test-id")
def test_build_single_embedding_response(self):
"""Test building response for single embedding."""
ret_data = [
{
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"meta_info": {"prompt_tokens": 5},
}
]
response = self.serving_embedding._build_embedding_response(ret_data)
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.data[0].object, "embedding")
self.assertEqual(response.usage.prompt_tokens, 5)
self.assertEqual(response.usage.total_tokens, 5)
self.assertEqual(response.usage.completion_tokens, 0)
def test_build_multiple_embedding_response(self):
"""Test building response for multiple embeddings."""
ret_data = [
{
"embedding": [0.1, 0.2, 0.3],
"meta_info": {"prompt_tokens": 3},
},
{
"embedding": [0.4, 0.5, 0.6],
"meta_info": {"prompt_tokens": 4},
},
]
response = self.serving_embedding._build_embedding_response(ret_data)
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(len(response.data), 2)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
self.assertEqual(response.data[1].index, 1)
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
self.assertEqual(response.usage.total_tokens, 7)
def test_handle_request_success(self):
"""Test successful embedding request handling."""
async def run_test():
# Mock the generate_request to return expected data
async def mock_generate():
yield {
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"meta_info": {"prompt_tokens": 5},
}
self.serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate()
)
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
asyncio.run(run_test())
def test_handle_request_validation_error(self):
"""Test handling request with validation error."""
async def run_test():
invalid_request = EmbeddingRequest(model="test-model", input="")
response = await self.serving_embedding.handle_request(
invalid_request, self.request
)
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 400)
asyncio.run(run_test())
def test_handle_request_generation_error(self):
"""Test handling request with generation error."""
async def run_test():
# Mock generate_request to raise an error
async def mock_generate_error():
raise ValueError("Generation failed")
yield # This won't be reached but needed for async generator
self.serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate_error()
)
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 400)
asyncio.run(run_test())
def test_handle_request_internal_error(self):
"""Test handling request with internal server error."""
async def run_test():
# Mock _convert_to_internal_request to raise an exception
with patch.object(
self.serving_embedding,
"_convert_to_internal_request",
side_effect=Exception("Internal error"),
):
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 500)
asyncio.run(run_test())
if __name__ == "__main__":
unittest.main(verbosity=2)
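
Taken together, the serving-class hunks above all make the same API change: each OpenAIServing* constructor now receives a TemplateManager alongside the TokenizerManager. A condensed sketch of the new wiring (names from the test hunks; manager construction elided):

# New two-argument constructors exercised by the updated tests:
chat = OpenAIServingChat(tokenizer_manager, template_manager)
completion = OpenAIServingCompletion(tokenizer_manager, template_manager)
embedding = OpenAIServingEmbedding(tokenizer_manager, template_manager)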

View File

@@ -29,6 +29,10 @@ suites = {
TestFile("models/test_reward_models.py", 132),
TestFile("models/test_vlm_models.py", 437),
TestFile("models/test_transformers_models.py", 320),
TestFile("openai/test_protocol.py", 10),
TestFile("openai/test_serving_chat.py", 10),
TestFile("openai/test_serving_completions.py", 10),
TestFile("openai/test_serving_embedding.py", 10),
TestFile("test_abort.py", 51),
TestFile("test_block_int8.py", 22),
TestFile("test_create_kvindices.py", 2),
@@ -49,6 +53,7 @@ suites = {
TestFile("test_hidden_states.py", 55),
TestFile("test_int8_kernel.py", 8),
TestFile("test_input_embeddings.py", 38),
TestFile("test_jinja_template_utils.py", 1),
TestFile("test_json_constrained.py", 98),
TestFile("test_large_max_new_tokens.py", 41),
TestFile("test_metrics.py", 32),
@@ -59,14 +64,8 @@ suites = {
TestFile("test_mla_fp8.py", 93),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 234),
TestFile("test_openai_adapter.py", 1),
TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149),
TestFile("openai/test_server.py", 120),
TestFile("openai/test_protocol.py", 60),
TestFile("openai/test_serving_chat.py", 120),
TestFile("openai/test_serving_completions.py", 120),
TestFile("openai/test_serving_embedding.py", 120),
TestFile("test_openai_server_hidden_states.py", 240),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),

View File

@@ -3,6 +3,7 @@ import unittest
from xgrammar import GrammarCompiler, TokenizerInfo
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
@@ -10,7 +11,6 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import Function, Tool
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

View File

@@ -5,8 +5,8 @@ Unit tests for OpenAI adapter utils.
import unittest
from unittest.mock import patch
from sglang.srt.openai_api.utils import (
detect_template_content_format,
from sglang.srt.jinja_template_utils import (
detect_jinja_template_content_format,
process_content_for_template_format,
)
from sglang.test.test_utils import CustomTestCase
@@ -33,7 +33,7 @@ class TestTemplateContentFormatDetection(CustomTestCase):
{%- endfor %}
"""
result = detect_template_content_format(llama4_pattern)
result = detect_jinja_template_content_format(llama4_pattern)
self.assertEqual(result, "openai")
def test_detect_deepseek_string_format(self):
@@ -46,19 +46,19 @@ class TestTemplateContentFormatDetection(CustomTestCase):
{%- endfor %}
"""
result = detect_template_content_format(deepseek_pattern)
result = detect_jinja_template_content_format(deepseek_pattern)
self.assertEqual(result, "string")
def test_detect_invalid_template(self):
"""Test handling of invalid template (should default to 'string')."""
invalid_pattern = "{{{{ invalid jinja syntax }}}}"
result = detect_template_content_format(invalid_pattern)
result = detect_jinja_template_content_format(invalid_pattern)
self.assertEqual(result, "string")
def test_detect_empty_template(self):
"""Test handling of empty template (should default to 'string')."""
result = detect_template_content_format("")
result = detect_jinja_template_content_format("")
self.assertEqual(result, "string")
def test_process_content_openai_format(self):

View File

@@ -235,6 +235,7 @@ class TestOpenAIServer(CustomTestCase):
)
is_firsts = {}
is_finished = {}
for response in generator:
usage = response.usage
if usage is not None:
@@ -244,6 +245,10 @@ class TestOpenAIServer(CustomTestCase):
continue
index = response.choices[0].index
finish_reason = response.choices[0].finish_reason
if finish_reason is not None:
is_finished[index] = True
data = response.choices[0].delta
if is_firsts.get(index, True):
@@ -253,7 +258,7 @@ class TestOpenAIServer(CustomTestCase):
is_firsts[index] = False
continue
if logprobs:
if logprobs and not is_finished.get(index, False):
assert response.choices[0].logprobs, f"logprobs was not returned"
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
@@ -271,7 +276,7 @@ class TestOpenAIServer(CustomTestCase):
assert (
isinstance(data.content, str)
or isinstance(data.reasoning_content, str)
or len(data.tool_calls) > 0
or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
or response.choices[0].finish_reason
)
assert response.id
@@ -282,152 +287,6 @@ class TestOpenAIServer(CustomTestCase):
index, True
), f"index {index} is not found in the response"
def _create_batch(self, mode, client):
if mode == "completion":
input_file_path = "complete_input.jsonl"
# write content to input file
content = [
{
"custom_id": "request-1",
"method": "POST",
"url": "/v1/completions",
"body": {
"model": "gpt-3.5-turbo-instruct",
"prompt": "List 3 names of famous soccer player: ",
"max_tokens": 20,
},
},
{
"custom_id": "request-2",
"method": "POST",
"url": "/v1/completions",
"body": {
"model": "gpt-3.5-turbo-instruct",
"prompt": "List 6 names of famous basketball player: ",
"max_tokens": 40,
},
},
{
"custom_id": "request-3",
"method": "POST",
"url": "/v1/completions",
"body": {
"model": "gpt-3.5-turbo-instruct",
"prompt": "List 6 names of famous tenniss player: ",
"max_tokens": 40,
},
},
]
else:
input_file_path = "chat_input.jsonl"
content = [
{
"custom_id": "request-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "gpt-3.5-turbo-0125",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": "Hello! List 3 NBA players and tell a story",
},
],
"max_tokens": 30,
},
},
{
"custom_id": "request-2",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "gpt-3.5-turbo-0125",
"messages": [
{"role": "system", "content": "You are an assistant. "},
{
"role": "user",
"content": "Hello! List three capital and tell a story",
},
],
"max_tokens": 50,
},
},
]
with open(input_file_path, "w") as file:
for line in content:
file.write(json.dumps(line) + "\n")
with open(input_file_path, "rb") as file:
uploaded_file = client.files.create(file=file, purpose="batch")
if mode == "completion":
endpoint = "/v1/completions"
elif mode == "chat":
endpoint = "/v1/chat/completions"
completion_window = "24h"
batch_job = client.batches.create(
input_file_id=uploaded_file.id,
endpoint=endpoint,
completion_window=completion_window,
)
return batch_job, content, uploaded_file
def run_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client)
while batch_job.status not in ["completed", "failed", "cancelled"]:
time.sleep(3)
print(
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
batch_job = client.batches.retrieve(batch_job.id)
assert (
batch_job.status == "completed"
), f"Batch job status is not completed: {batch_job.status}"
assert batch_job.request_counts.completed == len(content)
assert batch_job.request_counts.failed == 0
assert batch_job.request_counts.total == len(content)
result_file_id = batch_job.output_file_id
file_response = client.files.content(result_file_id)
result_content = file_response.read().decode("utf-8") # Decode bytes to string
results = [
json.loads(line)
for line in result_content.split("\n")
if line.strip() != ""
]
assert len(results) == len(content)
for delete_fid in [uploaded_file.id, result_file_id]:
del_response = client.files.delete(delete_fid)
assert del_response.deleted
def run_cancel_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client)
assert batch_job.status not in ["cancelling", "cancelled"]
batch_job = client.batches.cancel(batch_id=batch_job.id)
assert batch_job.status == "cancelling"
while batch_job.status not in ["failed", "cancelled"]:
batch_job = client.batches.retrieve(batch_job.id)
print(
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
time.sleep(3)
assert batch_job.status == "cancelled"
del_response = client.files.delete(uploaded_file.id)
assert del_response.deleted
def test_completion(self):
for echo in [False, True]:
for logprobs in [None, 5]:
@@ -467,14 +326,6 @@ class TestOpenAIServer(CustomTestCase):
for parallel_sample_num in [1, 2]:
self.run_chat_completion_stream(logprobs, parallel_sample_num)
def test_batch(self):
for mode in ["completion", "chat"]:
self.run_batch(mode)
def test_cancel_batch(self):
for mode in ["completion", "chat"]:
self.run_cancel_batch(mode)
def test_regex(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -559,6 +410,18 @@ The SmartHome Mini is a compact smart home assistant available in black or white
assert len(models) == 1
assert isinstance(getattr(models[0], "max_model_len", None), int)
def test_retrieve_model(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
# Test retrieving an existing model
retrieved_model = client.models.retrieve(self.model)
self.assertEqual(retrieved_model.id, self.model)
self.assertEqual(retrieved_model.root, self.model)
# Test retrieving a non-existent model
with self.assertRaises(openai.NotFoundError):
client.models.retrieve("non-existent-model")
# -------------------------------------------------------------------------
# EBNF Test Class: TestOpenAIServerEBNF
@@ -684,6 +547,31 @@ class TestOpenAIEmbedding(CustomTestCase):
self.assertTrue(len(response.data[0].embedding) > 0)
self.assertTrue(len(response.data[1].embedding) > 0)
def test_embedding_single_batch_str(self):
"""Test embedding with a List[str] and length equals to 1"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.embeddings.create(model=self.model, input=["Hello world"])
self.assertEqual(len(response.data), 1)
self.assertTrue(len(response.data[0].embedding) > 0)
def test_embedding_single_int_list(self):
"""Test embedding with a List[int] or List[List[int]]]"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.embeddings.create(
model=self.model,
input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
)
self.assertEqual(len(response.data), 1)
self.assertTrue(len(response.data[0].embedding) > 0)
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.embeddings.create(
model=self.model,
input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
)
self.assertEqual(len(response.data), 1)
self.assertTrue(len(response.data[0].embedding) > 0)
def test_empty_string_embedding(self):
"""Test embedding an empty string."""

View File

@@ -21,6 +21,7 @@ from transformers import (
from sglang import Engine
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
@@ -31,7 +32,6 @@ from sglang.srt.managers.schedule_batch import (
MultimodalInputs,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs

View File

@@ -15,7 +15,7 @@ from transformers import (
from sglang import Engine
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"