feat(oai refactor): Replace openai_api with entrypoints/openai (#7351)
Co-authored-by: Jin Pan <jpan236@wisc.edu>
This commit is contained in:
@@ -1,87 +0,0 @@
|
||||
# sglang/test/srt/openai/conftest.py
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import closing
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree # reuse SGLang helper
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
|
||||
SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
|
||||
DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
|
||||
|
||||
|
||||
def _pick_free_port() -> int:
|
||||
with closing(socket.socket()) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> None:
|
||||
start = time.perf_counter()
|
||||
while time.perf_counter() - start < timeout:
|
||||
if proc.poll() is not None: # crashed
|
||||
raise RuntimeError("api_server terminated prematurely")
|
||||
try:
|
||||
if requests.get(f"{base}/health", timeout=1).status_code == 200:
|
||||
return
|
||||
except requests.RequestException:
|
||||
pass
|
||||
time.sleep(0.4)
|
||||
raise RuntimeError("api_server readiness probe timed out")
|
||||
|
||||
|
||||
def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
|
||||
"""Spawn the draft OpenAI-compatible server and wait until it's ready."""
|
||||
port = _pick_free_port()
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
SERVER_MODULE,
|
||||
"--model-path",
|
||||
model,
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
str(port),
|
||||
*map(str, kw.get("args", [])),
|
||||
]
|
||||
env = {**os.environ, **kw.get("env", {})}
|
||||
|
||||
# Write logs to a temp file so the child never blocks on a full pipe.
|
||||
log_file = tempfile.NamedTemporaryFile("w+", delete=False)
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
base = f"http://127.0.0.1:{port}"
|
||||
try:
|
||||
_wait_until_healthy(proc, base, STARTUP_TIMEOUT)
|
||||
except Exception as e:
|
||||
proc.terminate()
|
||||
proc.wait(5)
|
||||
log_file.seek(0)
|
||||
print("\n--- api_server log ---\n", log_file.read(), file=sys.stderr)
|
||||
raise e
|
||||
return proc, base, log_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def openai_server() -> Generator[str, None, None]:
|
||||
"""PyTest fixture that provides the server's base URL and cleans up."""
|
||||
proc, base, log_file = launch_openai_server()
|
||||
yield base
|
||||
kill_process_tree(proc.pid)
|
||||
log_file.close()
|
||||
@@ -67,29 +67,6 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
class TestModelCard(unittest.TestCase):
|
||||
"""Test ModelCard protocol model"""
|
||||
|
||||
def test_basic_model_card_creation(self):
|
||||
"""Test basic model card creation with required fields"""
|
||||
card = ModelCard(id="test-model")
|
||||
self.assertEqual(card.id, "test-model")
|
||||
self.assertEqual(card.object, "model")
|
||||
self.assertEqual(card.owned_by, "sglang")
|
||||
self.assertIsInstance(card.created, int)
|
||||
self.assertIsNone(card.root)
|
||||
self.assertIsNone(card.max_model_len)
|
||||
|
||||
def test_model_card_with_optional_fields(self):
|
||||
"""Test model card with optional fields"""
|
||||
card = ModelCard(
|
||||
id="test-model",
|
||||
root="/path/to/model",
|
||||
max_model_len=2048,
|
||||
created=1234567890,
|
||||
)
|
||||
self.assertEqual(card.id, "test-model")
|
||||
self.assertEqual(card.root, "/path/to/model")
|
||||
self.assertEqual(card.max_model_len, 2048)
|
||||
self.assertEqual(card.created, 1234567890)
|
||||
|
||||
def test_model_card_serialization(self):
|
||||
"""Test model card JSON serialization"""
|
||||
card = ModelCard(id="test-model", max_model_len=4096)
|
||||
@@ -120,53 +97,6 @@ class TestModelList(unittest.TestCase):
|
||||
self.assertEqual(model_list.data[1].id, "model-2")
|
||||
|
||||
|
||||
class TestErrorResponse(unittest.TestCase):
|
||||
"""Test ErrorResponse protocol model"""
|
||||
|
||||
def test_basic_error_response(self):
|
||||
"""Test basic error response creation"""
|
||||
error = ErrorResponse(
|
||||
message="Invalid request", type="BadRequestError", code=400
|
||||
)
|
||||
self.assertEqual(error.object, "error")
|
||||
self.assertEqual(error.message, "Invalid request")
|
||||
self.assertEqual(error.type, "BadRequestError")
|
||||
self.assertEqual(error.code, 400)
|
||||
self.assertIsNone(error.param)
|
||||
|
||||
def test_error_response_with_param(self):
|
||||
"""Test error response with parameter"""
|
||||
error = ErrorResponse(
|
||||
message="Invalid temperature",
|
||||
type="ValidationError",
|
||||
code=422,
|
||||
param="temperature",
|
||||
)
|
||||
self.assertEqual(error.param, "temperature")
|
||||
|
||||
|
||||
class TestUsageInfo(unittest.TestCase):
|
||||
"""Test UsageInfo protocol model"""
|
||||
|
||||
def test_basic_usage_info(self):
|
||||
"""Test basic usage info creation"""
|
||||
usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
self.assertEqual(usage.prompt_tokens, 10)
|
||||
self.assertEqual(usage.completion_tokens, 20)
|
||||
self.assertEqual(usage.total_tokens, 30)
|
||||
self.assertIsNone(usage.prompt_tokens_details)
|
||||
|
||||
def test_usage_info_with_cache_details(self):
|
||||
"""Test usage info with cache details"""
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=20,
|
||||
total_tokens=30,
|
||||
prompt_tokens_details={"cached_tokens": 5},
|
||||
)
|
||||
self.assertEqual(usage.prompt_tokens_details, {"cached_tokens": 5})
|
||||
|
||||
|
||||
class TestCompletionRequest(unittest.TestCase):
|
||||
"""Test CompletionRequest protocol model"""
|
||||
|
||||
@@ -181,30 +111,6 @@ class TestCompletionRequest(unittest.TestCase):
|
||||
self.assertFalse(request.stream) # default
|
||||
self.assertFalse(request.echo) # default
|
||||
|
||||
def test_completion_request_with_options(self):
|
||||
"""Test completion request with various options"""
|
||||
request = CompletionRequest(
|
||||
model="test-model",
|
||||
prompt=["Hello", "world"],
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
n=2,
|
||||
stream=True,
|
||||
echo=True,
|
||||
stop=[".", "!"],
|
||||
logprobs=5,
|
||||
)
|
||||
self.assertEqual(request.prompt, ["Hello", "world"])
|
||||
self.assertEqual(request.max_tokens, 100)
|
||||
self.assertEqual(request.temperature, 0.7)
|
||||
self.assertEqual(request.top_p, 0.9)
|
||||
self.assertEqual(request.n, 2)
|
||||
self.assertTrue(request.stream)
|
||||
self.assertTrue(request.echo)
|
||||
self.assertEqual(request.stop, [".", "!"])
|
||||
self.assertEqual(request.logprobs, 5)
|
||||
|
||||
def test_completion_request_sglang_extensions(self):
|
||||
"""Test completion request with SGLang-specific extensions"""
|
||||
request = CompletionRequest(
|
||||
@@ -233,26 +139,6 @@ class TestCompletionRequest(unittest.TestCase):
|
||||
CompletionRequest(model="test-model") # missing prompt
|
||||
|
||||
|
||||
class TestCompletionResponse(unittest.TestCase):
|
||||
"""Test CompletionResponse protocol model"""
|
||||
|
||||
def test_basic_completion_response(self):
|
||||
"""Test basic completion response"""
|
||||
choice = CompletionResponseChoice(
|
||||
index=0, text="Hello world!", finish_reason="stop"
|
||||
)
|
||||
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
|
||||
response = CompletionResponse(
|
||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
||||
)
|
||||
self.assertEqual(response.id, "test-id")
|
||||
self.assertEqual(response.object, "text_completion")
|
||||
self.assertEqual(response.model, "test-model")
|
||||
self.assertEqual(len(response.choices), 1)
|
||||
self.assertEqual(response.choices[0].text, "Hello world!")
|
||||
self.assertEqual(response.usage.total_tokens, 5)
|
||||
|
||||
|
||||
class TestChatCompletionRequest(unittest.TestCase):
|
||||
"""Test ChatCompletionRequest protocol model"""
|
||||
|
||||
@@ -268,48 +154,6 @@ class TestChatCompletionRequest(unittest.TestCase):
|
||||
self.assertFalse(request.stream) # default
|
||||
self.assertEqual(request.tool_choice, "none") # default when no tools
|
||||
|
||||
def test_chat_completion_with_multimodal_content(self):
|
||||
"""Test chat completion with multimodal content"""
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
request = ChatCompletionRequest(model="test-model", messages=messages)
|
||||
self.assertEqual(len(request.messages[0].content), 2)
|
||||
self.assertEqual(request.messages[0].content[0].type, "text")
|
||||
self.assertEqual(request.messages[0].content[1].type, "image_url")
|
||||
|
||||
def test_chat_completion_with_tools(self):
|
||||
"""Test chat completion with tools"""
|
||||
messages = [{"role": "user", "content": "What's the weather?"}]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather information",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model", messages=messages, tools=tools
|
||||
)
|
||||
self.assertEqual(len(request.tools), 1)
|
||||
self.assertEqual(request.tools[0].function.name, "get_weather")
|
||||
self.assertEqual(request.tool_choice, "auto") # default when tools present
|
||||
|
||||
def test_chat_completion_tool_choice_validation(self):
|
||||
"""Test tool choice validation logic"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
@@ -349,289 +193,6 @@ class TestChatCompletionRequest(unittest.TestCase):
|
||||
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
|
||||
|
||||
|
||||
class TestChatCompletionResponse(unittest.TestCase):
|
||||
"""Test ChatCompletionResponse protocol model"""
|
||||
|
||||
def test_basic_chat_completion_response(self):
|
||||
"""Test basic chat completion response"""
|
||||
message = ChatMessage(role="assistant", content="Hello there!")
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0, message=message, finish_reason="stop"
|
||||
)
|
||||
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
|
||||
response = ChatCompletionResponse(
|
||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
||||
)
|
||||
self.assertEqual(response.id, "test-id")
|
||||
self.assertEqual(response.object, "chat.completion")
|
||||
self.assertEqual(response.model, "test-model")
|
||||
self.assertEqual(len(response.choices), 1)
|
||||
self.assertEqual(response.choices[0].message.content, "Hello there!")
|
||||
|
||||
def test_chat_completion_response_with_tool_calls(self):
|
||||
"""Test chat completion response with tool calls"""
|
||||
tool_call = ToolCall(
|
||||
id="call_123",
|
||||
function=FunctionResponse(
|
||||
name="get_weather", arguments='{"location": "San Francisco"}'
|
||||
),
|
||||
)
|
||||
message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call])
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0, message=message, finish_reason="tool_calls"
|
||||
)
|
||||
usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
||||
response = ChatCompletionResponse(
|
||||
id="test-id", model="test-model", choices=[choice], usage=usage
|
||||
)
|
||||
self.assertEqual(
|
||||
response.choices[0].message.tool_calls[0].function.name, "get_weather"
|
||||
)
|
||||
self.assertEqual(response.choices[0].finish_reason, "tool_calls")
|
||||
|
||||
|
||||
class TestEmbeddingRequest(unittest.TestCase):
|
||||
"""Test EmbeddingRequest protocol model"""
|
||||
|
||||
def test_basic_embedding_request(self):
|
||||
"""Test basic embedding request"""
|
||||
request = EmbeddingRequest(model="test-model", input="Hello world")
|
||||
self.assertEqual(request.model, "test-model")
|
||||
self.assertEqual(request.input, "Hello world")
|
||||
self.assertEqual(request.encoding_format, "float") # default
|
||||
self.assertIsNone(request.dimensions) # default
|
||||
|
||||
def test_embedding_request_with_list_input(self):
|
||||
"""Test embedding request with list input"""
|
||||
request = EmbeddingRequest(
|
||||
model="test-model", input=["Hello", "world"], dimensions=512
|
||||
)
|
||||
self.assertEqual(request.input, ["Hello", "world"])
|
||||
self.assertEqual(request.dimensions, 512)
|
||||
|
||||
def test_multimodal_embedding_request(self):
|
||||
"""Test multimodal embedding request"""
|
||||
multimodal_input = [
|
||||
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
|
||||
MultimodalEmbeddingInput(text="World", image=None),
|
||||
]
|
||||
request = EmbeddingRequest(model="test-model", input=multimodal_input)
|
||||
self.assertEqual(len(request.input), 2)
|
||||
self.assertEqual(request.input[0].text, "Hello")
|
||||
self.assertEqual(request.input[0].image, "base64_image_data")
|
||||
self.assertEqual(request.input[1].text, "World")
|
||||
self.assertIsNone(request.input[1].image)
|
||||
|
||||
|
||||
class TestEmbeddingResponse(unittest.TestCase):
|
||||
"""Test EmbeddingResponse protocol model"""
|
||||
|
||||
def test_basic_embedding_response(self):
|
||||
"""Test basic embedding response"""
|
||||
embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)
|
||||
usage = UsageInfo(prompt_tokens=3, total_tokens=3)
|
||||
response = EmbeddingResponse(
|
||||
data=[embedding_obj], model="test-model", usage=usage
|
||||
)
|
||||
self.assertEqual(response.object, "list")
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
|
||||
self.assertEqual(response.data[0].index, 0)
|
||||
self.assertEqual(response.usage.prompt_tokens, 3)
|
||||
|
||||
|
||||
class TestScoringRequest(unittest.TestCase):
|
||||
"""Test ScoringRequest protocol model"""
|
||||
|
||||
def test_basic_scoring_request(self):
|
||||
"""Test basic scoring request"""
|
||||
request = ScoringRequest(
|
||||
model="test-model", query="Hello", items=["World", "Earth"]
|
||||
)
|
||||
self.assertEqual(request.model, "test-model")
|
||||
self.assertEqual(request.query, "Hello")
|
||||
self.assertEqual(request.items, ["World", "Earth"])
|
||||
self.assertFalse(request.apply_softmax) # default
|
||||
self.assertFalse(request.item_first) # default
|
||||
|
||||
def test_scoring_request_with_token_ids(self):
|
||||
"""Test scoring request with token IDs"""
|
||||
request = ScoringRequest(
|
||||
model="test-model",
|
||||
query=[1, 2, 3],
|
||||
items=[[4, 5], [6, 7]],
|
||||
label_token_ids=[8, 9],
|
||||
apply_softmax=True,
|
||||
item_first=True,
|
||||
)
|
||||
self.assertEqual(request.query, [1, 2, 3])
|
||||
self.assertEqual(request.items, [[4, 5], [6, 7]])
|
||||
self.assertEqual(request.label_token_ids, [8, 9])
|
||||
self.assertTrue(request.apply_softmax)
|
||||
self.assertTrue(request.item_first)
|
||||
|
||||
|
||||
class TestScoringResponse(unittest.TestCase):
|
||||
"""Test ScoringResponse protocol model"""
|
||||
|
||||
def test_basic_scoring_response(self):
|
||||
"""Test basic scoring response"""
|
||||
response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
|
||||
self.assertEqual(response.object, "scoring")
|
||||
self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]])
|
||||
self.assertEqual(response.model, "test-model")
|
||||
self.assertIsNone(response.usage) # default
|
||||
|
||||
|
||||
class TestFileOperations(unittest.TestCase):
|
||||
"""Test file operation protocol models"""
|
||||
|
||||
def test_file_request(self):
|
||||
"""Test file request model"""
|
||||
file_data = b"test file content"
|
||||
request = FileRequest(file=file_data, purpose="batch")
|
||||
self.assertEqual(request.file, file_data)
|
||||
self.assertEqual(request.purpose, "batch")
|
||||
|
||||
def test_file_response(self):
|
||||
"""Test file response model"""
|
||||
response = FileResponse(
|
||||
id="file-123",
|
||||
bytes=1024,
|
||||
created_at=1234567890,
|
||||
filename="test.jsonl",
|
||||
purpose="batch",
|
||||
)
|
||||
self.assertEqual(response.id, "file-123")
|
||||
self.assertEqual(response.object, "file")
|
||||
self.assertEqual(response.bytes, 1024)
|
||||
self.assertEqual(response.filename, "test.jsonl")
|
||||
|
||||
def test_file_delete_response(self):
|
||||
"""Test file delete response model"""
|
||||
response = FileDeleteResponse(id="file-123", deleted=True)
|
||||
self.assertEqual(response.id, "file-123")
|
||||
self.assertEqual(response.object, "file")
|
||||
self.assertTrue(response.deleted)
|
||||
|
||||
|
||||
class TestBatchOperations(unittest.TestCase):
|
||||
"""Test batch operation protocol models"""
|
||||
|
||||
def test_batch_request(self):
|
||||
"""Test batch request model"""
|
||||
request = BatchRequest(
|
||||
input_file_id="file-123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"custom": "value"},
|
||||
)
|
||||
self.assertEqual(request.input_file_id, "file-123")
|
||||
self.assertEqual(request.endpoint, "/v1/chat/completions")
|
||||
self.assertEqual(request.completion_window, "24h")
|
||||
self.assertEqual(request.metadata, {"custom": "value"})
|
||||
|
||||
def test_batch_response(self):
|
||||
"""Test batch response model"""
|
||||
response = BatchResponse(
|
||||
id="batch-123",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="file-123",
|
||||
completion_window="24h",
|
||||
created_at=1234567890,
|
||||
)
|
||||
self.assertEqual(response.id, "batch-123")
|
||||
self.assertEqual(response.object, "batch")
|
||||
self.assertEqual(response.status, "validating") # default
|
||||
self.assertEqual(response.endpoint, "/v1/chat/completions")
|
||||
|
||||
|
||||
class TestResponseFormats(unittest.TestCase):
|
||||
"""Test response format protocol models"""
|
||||
|
||||
def test_basic_response_format(self):
|
||||
"""Test basic response format"""
|
||||
format_obj = ResponseFormat(type="json_object")
|
||||
self.assertEqual(format_obj.type, "json_object")
|
||||
self.assertIsNone(format_obj.json_schema)
|
||||
|
||||
def test_json_schema_response_format(self):
|
||||
"""Test JSON schema response format"""
|
||||
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
|
||||
json_schema = JsonSchemaResponseFormat(
|
||||
name="person_schema", description="Person schema", schema=schema
|
||||
)
|
||||
format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
|
||||
self.assertEqual(format_obj.type, "json_schema")
|
||||
self.assertEqual(format_obj.json_schema.name, "person_schema")
|
||||
self.assertEqual(format_obj.json_schema.schema_, schema)
|
||||
|
||||
def test_structural_tag_response_format(self):
|
||||
"""Test structural tag response format"""
|
||||
structures = [
|
||||
{
|
||||
"begin": "<thinking>",
|
||||
"schema_": {"type": "string"},
|
||||
"end": "</thinking>",
|
||||
}
|
||||
]
|
||||
format_obj = StructuralTagResponseFormat(
|
||||
type="structural_tag", structures=structures, triggers=["think"]
|
||||
)
|
||||
self.assertEqual(format_obj.type, "structural_tag")
|
||||
self.assertEqual(len(format_obj.structures), 1)
|
||||
self.assertEqual(format_obj.triggers, ["think"])
|
||||
|
||||
|
||||
class TestLogProbs(unittest.TestCase):
|
||||
"""Test LogProbs protocol models"""
|
||||
|
||||
def test_basic_logprobs(self):
|
||||
"""Test basic LogProbs model"""
|
||||
logprobs = LogProbs(
|
||||
text_offset=[0, 5, 11],
|
||||
token_logprobs=[-0.1, -0.2, -0.3],
|
||||
tokens=["Hello", " ", "world"],
|
||||
top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
|
||||
)
|
||||
self.assertEqual(len(logprobs.tokens), 3)
|
||||
self.assertEqual(logprobs.tokens, ["Hello", " ", "world"])
|
||||
self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3])
|
||||
|
||||
def test_choice_logprobs(self):
|
||||
"""Test ChoiceLogprobs model"""
|
||||
token_logprob = ChatCompletionTokenLogprob(
|
||||
token="Hello",
|
||||
bytes=[72, 101, 108, 108, 111],
|
||||
logprob=-0.1,
|
||||
top_logprobs=[
|
||||
TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1)
|
||||
],
|
||||
)
|
||||
choice_logprobs = ChoiceLogprobs(content=[token_logprob])
|
||||
self.assertEqual(len(choice_logprobs.content), 1)
|
||||
self.assertEqual(choice_logprobs.content[0].token, "Hello")
|
||||
|
||||
|
||||
class TestStreamingModels(unittest.TestCase):
|
||||
"""Test streaming response models"""
|
||||
|
||||
def test_stream_options(self):
|
||||
"""Test StreamOptions model"""
|
||||
options = StreamOptions(include_usage=True)
|
||||
self.assertTrue(options.include_usage)
|
||||
|
||||
def test_chat_completion_stream_response(self):
|
||||
"""Test ChatCompletionStreamResponse model"""
|
||||
delta = DeltaMessage(role="assistant", content="Hello")
|
||||
choice = ChatCompletionResponseStreamChoice(index=0, delta=delta)
|
||||
response = ChatCompletionStreamResponse(
|
||||
id="test-id", model="test-model", choices=[choice]
|
||||
)
|
||||
self.assertEqual(response.object, "chat.completion.chunk")
|
||||
self.assertEqual(response.choices[0].delta.content, "Hello")
|
||||
|
||||
|
||||
class TestModelSerialization(unittest.TestCase):
|
||||
"""Test model serialization with hidden states"""
|
||||
|
||||
@@ -680,11 +241,6 @@ class TestModelSerialization(unittest.TestCase):
|
||||
class TestValidationEdgeCases(unittest.TestCase):
|
||||
"""Test edge cases and validation scenarios"""
|
||||
|
||||
def test_empty_messages_validation(self):
|
||||
"""Test validation with empty messages"""
|
||||
with self.assertRaises(ValidationError):
|
||||
ChatCompletionRequest(model="test-model", messages=[])
|
||||
|
||||
def test_invalid_tool_choice_type(self):
|
||||
"""Test invalid tool choice type"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
@@ -698,13 +254,6 @@ class TestValidationEdgeCases(unittest.TestCase):
|
||||
with self.assertRaises(ValidationError):
|
||||
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
|
||||
|
||||
def test_invalid_temperature_range(self):
|
||||
"""Test invalid temperature values"""
|
||||
# Note: The current protocol doesn't enforce temperature range,
|
||||
# but this test documents expected behavior
|
||||
request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
|
||||
self.assertEqual(request.temperature, 5.0) # Currently allowed
|
||||
|
||||
def test_model_serialization_roundtrip(self):
|
||||
"""Test that models can be serialized and deserialized"""
|
||||
original_request = ChatCompletionRequest(
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
# sglang/test/srt/openai/test_server.py
|
||||
import requests
|
||||
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID
|
||||
|
||||
|
||||
def test_health(openai_server: str):
|
||||
r = requests.get(f"{openai_server}/health")
|
||||
assert r.status_code == 200
|
||||
# FastAPI returns an empty body → r.text == ""
|
||||
assert r.text == ""
|
||||
|
||||
|
||||
def test_models_endpoint(openai_server: str):
|
||||
r = requests.get(f"{openai_server}/v1/models")
|
||||
assert r.status_code == 200, r.text
|
||||
payload = r.json()
|
||||
|
||||
# Basic contract
|
||||
assert "data" in payload and isinstance(payload["data"], list) and payload["data"]
|
||||
|
||||
# Validate fields of the first model card
|
||||
first = payload["data"][0]
|
||||
for key in ("id", "root", "max_model_len"):
|
||||
assert key in first, f"missing {key} in {first}"
|
||||
|
||||
# max_model_len must be positive
|
||||
assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0
|
||||
|
||||
# The server should report the same model id we launched it with
|
||||
ids = {m["id"] for m in payload["data"]}
|
||||
assert MODEL_ID in ids
|
||||
|
||||
|
||||
def test_get_model_info(openai_server: str):
|
||||
r = requests.get(f"{openai_server}/get_model_info")
|
||||
assert r.status_code == 200, r.text
|
||||
info = r.json()
|
||||
|
||||
expected_keys = {"model_path", "tokenizer_path", "is_generation"}
|
||||
assert expected_keys.issubset(info.keys())
|
||||
|
||||
# model_path must end with the one we passed on the CLI
|
||||
assert info["model_path"].endswith(MODEL_ID)
|
||||
|
||||
# is_generation is documented as a boolean
|
||||
assert isinstance(info["is_generation"], bool)
|
||||
|
||||
|
||||
def test_unknown_route_returns_404(openai_server: str):
|
||||
r = requests.get(f"{openai_server}/definitely-not-a-real-route")
|
||||
assert r.status_code == 404
|
||||
@@ -57,11 +57,21 @@ class _MockTokenizerManager:
|
||||
self.create_abort_task = Mock()
|
||||
|
||||
|
||||
class _MockTemplateManager:
|
||||
"""Minimal mock for TemplateManager."""
|
||||
|
||||
def __init__(self):
|
||||
self.chat_template_name: Optional[str] = "llama-3"
|
||||
self.jinja_template_content_format: Optional[str] = None
|
||||
self.completion_template_name: Optional[str] = None
|
||||
|
||||
|
||||
class ServingChatTestCase(unittest.TestCase):
|
||||
# ------------- common fixtures -------------
|
||||
def setUp(self):
|
||||
self.tm = _MockTokenizerManager()
|
||||
self.chat = OpenAIServingChat(self.tm)
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.chat = OpenAIServingChat(self.tm, self.template_manager)
|
||||
|
||||
# frequently reused requests
|
||||
self.basic_req = ChatCompletionRequest(
|
||||
@@ -109,96 +119,6 @@ class ServingChatTestCase(unittest.TestCase):
|
||||
self.assertFalse(adapted.stream)
|
||||
self.assertEqual(processed, self.basic_req)
|
||||
|
||||
# # ------------- tool-call branch -------------
|
||||
# def test_tool_call_request_conversion(self):
|
||||
# req = ChatCompletionRequest(
|
||||
# model="x",
|
||||
# messages=[{"role": "user", "content": "Weather?"}],
|
||||
# tools=[
|
||||
# {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "get_weather",
|
||||
# "parameters": {"type": "object", "properties": {}},
|
||||
# },
|
||||
# }
|
||||
# ],
|
||||
# tool_choice="auto",
|
||||
# )
|
||||
|
||||
# with patch.object(
|
||||
# self.chat,
|
||||
# "_process_messages",
|
||||
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||
# ):
|
||||
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
||||
# self.assertEqual(adapted.rid, "rid")
|
||||
|
||||
# def test_tool_choice_none(self):
|
||||
# req = ChatCompletionRequest(
|
||||
# model="x",
|
||||
# messages=[{"role": "user", "content": "Hi"}],
|
||||
# tools=[{"type": "function", "function": {"name": "noop"}}],
|
||||
# tool_choice="none",
|
||||
# )
|
||||
# with patch.object(
|
||||
# self.chat,
|
||||
# "_process_messages",
|
||||
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||
# ):
|
||||
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
||||
# self.assertEqual(adapted.rid, "rid")
|
||||
|
||||
# ------------- multimodal branch -------------
|
||||
def test_multimodal_request_with_images(self):
|
||||
self.tm.model_config.is_multimodal = True
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in the image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,"},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
self.chat,
|
||||
"_apply_jinja_template",
|
||||
return_value=("prompt", [1, 2], ["img"], None, [], []),
|
||||
), patch.object(
|
||||
self.chat,
|
||||
"_apply_conversation_template",
|
||||
return_value=("prompt", ["img"], None, [], []),
|
||||
):
|
||||
out = self.chat._process_messages(req, True)
|
||||
_, _, image_data, *_ = out
|
||||
self.assertEqual(image_data, ["img"])
|
||||
|
||||
# ------------- template handling -------------
|
||||
def test_jinja_template_processing(self):
|
||||
req = ChatCompletionRequest(
|
||||
model="x", messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
self.tm.chat_template_name = None
|
||||
self.tm.tokenizer.chat_template = "<jinja>"
|
||||
|
||||
with patch.object(
|
||||
self.chat,
|
||||
"_apply_jinja_template",
|
||||
return_value=("processed", [1], None, None, [], ["</s>"]),
|
||||
), patch("builtins.hasattr", return_value=True):
|
||||
prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
|
||||
self.assertEqual(prompt, "processed")
|
||||
self.assertEqual(prompt_ids, [1])
|
||||
|
||||
# ------------- sampling-params -------------
|
||||
def test_sampling_param_build(self):
|
||||
req = ChatCompletionRequest(
|
||||
|
||||
@@ -5,6 +5,7 @@ Run with:
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
|
||||
@@ -12,6 +13,17 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
|
||||
|
||||
class _MockTemplateManager:
|
||||
"""Minimal mock for TemplateManager."""
|
||||
|
||||
def __init__(self):
|
||||
self.chat_template_name: Optional[str] = None
|
||||
self.jinja_template_content_format: Optional[str] = None
|
||||
self.completion_template_name: Optional[str] = (
|
||||
None # Set to None to avoid template processing
|
||||
)
|
||||
|
||||
|
||||
class ServingCompletionTestCase(unittest.TestCase):
|
||||
"""Bundle all prompt/echo tests in one TestCase."""
|
||||
|
||||
@@ -31,7 +43,8 @@ class ServingCompletionTestCase(unittest.TestCase):
|
||||
tm.generate_request = AsyncMock()
|
||||
tm.create_abort_task = Mock()
|
||||
|
||||
self.sc = OpenAIServingCompletion(tm)
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.sc = OpenAIServingCompletion(tm, self.template_manager)
|
||||
|
||||
# ---------- prompt-handling ----------
|
||||
def test_single_string_prompt(self):
|
||||
@@ -44,20 +57,6 @@ class ServingCompletionTestCase(unittest.TestCase):
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
|
||||
|
||||
def test_completion_template_handling(self):
|
||||
req = CompletionRequest(
|
||||
model="x", prompt="def f():", suffix="return 1", max_tokens=100
|
||||
)
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
|
||||
return_value=True,
|
||||
), patch(
|
||||
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
|
||||
return_value="processed_prompt",
|
||||
):
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.text, "processed_prompt")
|
||||
|
||||
# ---------- echo-handling ----------
|
||||
def test_echo_with_string_prompt_streaming(self):
|
||||
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
|
||||
|
||||
@@ -5,25 +5,16 @@ These tests ensure that the embedding serving implementation maintains compatibi
|
||||
with the original adapter.py functionality and follows OpenAI API specifications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from pydantic_core import ValidationError
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
EmbeddingObject,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ErrorResponse,
|
||||
MultimodalEmbeddingInput,
|
||||
UsageInfo,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
||||
@@ -58,11 +49,22 @@ class _MockTokenizerManager:
|
||||
self.generate_request = Mock(return_value=mock_generate_embedding())
|
||||
|
||||
|
||||
# Mock TemplateManager for embedding tests
|
||||
class _MockTemplateManager:
|
||||
def __init__(self):
|
||||
self.chat_template_name = None # None for embeddings usually
|
||||
self.jinja_template_content_format = None
|
||||
self.completion_template_name = None
|
||||
|
||||
|
||||
class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.tokenizer_manager = _MockTokenizerManager()
|
||||
self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.serving_embedding = OpenAIServingEmbedding(
|
||||
self.tokenizer_manager, self.template_manager
|
||||
)
|
||||
|
||||
self.request = Mock(spec=Request)
|
||||
self.request.headers = {}
|
||||
@@ -141,132 +143,6 @@ class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
self.assertIsNone(adapted_request.image_data[1])
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
|
||||
def test_build_single_embedding_response(self):
|
||||
"""Test building response for single embedding."""
|
||||
ret_data = [
|
||||
{
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"meta_info": {"prompt_tokens": 5},
|
||||
}
|
||||
]
|
||||
|
||||
response = self.serving_embedding._build_embedding_response(ret_data)
|
||||
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(response.model, "test-model")
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
self.assertEqual(response.data[0].index, 0)
|
||||
self.assertEqual(response.data[0].object, "embedding")
|
||||
self.assertEqual(response.usage.prompt_tokens, 5)
|
||||
self.assertEqual(response.usage.total_tokens, 5)
|
||||
self.assertEqual(response.usage.completion_tokens, 0)
|
||||
|
||||
def test_build_multiple_embedding_response(self):
|
||||
"""Test building response for multiple embeddings."""
|
||||
ret_data = [
|
||||
{
|
||||
"embedding": [0.1, 0.2, 0.3],
|
||||
"meta_info": {"prompt_tokens": 3},
|
||||
},
|
||||
{
|
||||
"embedding": [0.4, 0.5, 0.6],
|
||||
"meta_info": {"prompt_tokens": 4},
|
||||
},
|
||||
]
|
||||
|
||||
response = self.serving_embedding._build_embedding_response(ret_data)
|
||||
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(len(response.data), 2)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
|
||||
self.assertEqual(response.data[0].index, 0)
|
||||
self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
|
||||
self.assertEqual(response.data[1].index, 1)
|
||||
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
|
||||
self.assertEqual(response.usage.total_tokens, 7)
|
||||
|
||||
def test_handle_request_success(self):
|
||||
"""Test successful embedding request handling."""
|
||||
|
||||
async def run_test():
|
||||
# Mock the generate_request to return expected data
|
||||
async def mock_generate():
|
||||
yield {
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"meta_info": {"prompt_tokens": 5},
|
||||
}
|
||||
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate()
|
||||
)
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_validation_error(self):
|
||||
"""Test handling request with validation error."""
|
||||
|
||||
async def run_test():
|
||||
invalid_request = EmbeddingRequest(model="test-model", input="")
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
invalid_request, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_generation_error(self):
|
||||
"""Test handling request with generation error."""
|
||||
|
||||
async def run_test():
|
||||
# Mock generate_request to raise an error
|
||||
async def mock_generate_error():
|
||||
raise ValueError("Generation failed")
|
||||
yield # This won't be reached but needed for async generator
|
||||
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate_error()
|
||||
)
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_internal_error(self):
|
||||
"""Test handling request with internal server error."""
|
||||
|
||||
async def run_test():
|
||||
# Mock _convert_to_internal_request to raise an exception
|
||||
with patch.object(
|
||||
self.serving_embedding,
|
||||
"_convert_to_internal_request",
|
||||
side_effect=Exception("Internal error"),
|
||||
):
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 500)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -29,6 +29,10 @@ suites = {
|
||||
TestFile("models/test_reward_models.py", 132),
|
||||
TestFile("models/test_vlm_models.py", 437),
|
||||
TestFile("models/test_transformers_models.py", 320),
|
||||
TestFile("openai/test_protocol.py", 10),
|
||||
TestFile("openai/test_serving_chat.py", 10),
|
||||
TestFile("openai/test_serving_completions.py", 10),
|
||||
TestFile("openai/test_serving_embedding.py", 10),
|
||||
TestFile("test_abort.py", 51),
|
||||
TestFile("test_block_int8.py", 22),
|
||||
TestFile("test_create_kvindices.py", 2),
|
||||
@@ -49,6 +53,7 @@ suites = {
|
||||
TestFile("test_hidden_states.py", 55),
|
||||
TestFile("test_int8_kernel.py", 8),
|
||||
TestFile("test_input_embeddings.py", 38),
|
||||
TestFile("test_jinja_template_utils.py", 1),
|
||||
TestFile("test_json_constrained.py", 98),
|
||||
TestFile("test_large_max_new_tokens.py", 41),
|
||||
TestFile("test_metrics.py", 32),
|
||||
@@ -59,14 +64,8 @@ suites = {
|
||||
TestFile("test_mla_fp8.py", 93),
|
||||
TestFile("test_no_chunked_prefill.py", 108),
|
||||
TestFile("test_no_overlap_scheduler.py", 234),
|
||||
TestFile("test_openai_adapter.py", 1),
|
||||
TestFile("test_openai_function_calling.py", 60),
|
||||
TestFile("test_openai_server.py", 149),
|
||||
TestFile("openai/test_server.py", 120),
|
||||
TestFile("openai/test_protocol.py", 60),
|
||||
TestFile("openai/test_serving_chat.py", 120),
|
||||
TestFile("openai/test_serving_completions.py", 120),
|
||||
TestFile("openai/test_serving_embedding.py", 120),
|
||||
TestFile("test_openai_server_hidden_states.py", 240),
|
||||
TestFile("test_penalty.py", 41),
|
||||
TestFile("test_page_size.py", 60),
|
||||
|
||||
@@ -3,6 +3,7 @@ import unittest
|
||||
|
||||
from xgrammar import GrammarCompiler, TokenizerInfo
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Function, Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||
@@ -10,7 +11,6 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.openai_api.protocol import Function, Tool
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ Unit tests for OpenAI adapter utils.
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from sglang.srt.openai_api.utils import (
|
||||
detect_template_content_format,
|
||||
from sglang.srt.jinja_template_utils import (
|
||||
detect_jinja_template_content_format,
|
||||
process_content_for_template_format,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
@@ -33,7 +33,7 @@ class TestTemplateContentFormatDetection(CustomTestCase):
|
||||
{%- endfor %}
|
||||
"""
|
||||
|
||||
result = detect_template_content_format(llama4_pattern)
|
||||
result = detect_jinja_template_content_format(llama4_pattern)
|
||||
self.assertEqual(result, "openai")
|
||||
|
||||
def test_detect_deepseek_string_format(self):
|
||||
@@ -46,19 +46,19 @@ class TestTemplateContentFormatDetection(CustomTestCase):
|
||||
{%- endfor %}
|
||||
"""
|
||||
|
||||
result = detect_template_content_format(deepseek_pattern)
|
||||
result = detect_jinja_template_content_format(deepseek_pattern)
|
||||
self.assertEqual(result, "string")
|
||||
|
||||
def test_detect_invalid_template(self):
|
||||
"""Test handling of invalid template (should default to 'string')."""
|
||||
invalid_pattern = "{{{{ invalid jinja syntax }}}}"
|
||||
|
||||
result = detect_template_content_format(invalid_pattern)
|
||||
result = detect_jinja_template_content_format(invalid_pattern)
|
||||
self.assertEqual(result, "string")
|
||||
|
||||
def test_detect_empty_template(self):
|
||||
"""Test handling of empty template (should default to 'string')."""
|
||||
result = detect_template_content_format("")
|
||||
result = detect_jinja_template_content_format("")
|
||||
self.assertEqual(result, "string")
|
||||
|
||||
def test_process_content_openai_format(self):
|
||||
@@ -235,6 +235,7 @@ class TestOpenAIServer(CustomTestCase):
|
||||
)
|
||||
|
||||
is_firsts = {}
|
||||
is_finished = {}
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
@@ -244,6 +245,10 @@ class TestOpenAIServer(CustomTestCase):
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
if finish_reason is not None:
|
||||
is_finished[index] = True
|
||||
|
||||
data = response.choices[0].delta
|
||||
|
||||
if is_firsts.get(index, True):
|
||||
@@ -253,7 +258,7 @@ class TestOpenAIServer(CustomTestCase):
|
||||
is_firsts[index] = False
|
||||
continue
|
||||
|
||||
if logprobs:
|
||||
if logprobs and not is_finished.get(index, False):
|
||||
assert response.choices[0].logprobs, f"logprobs was not returned"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
@@ -271,7 +276,7 @@ class TestOpenAIServer(CustomTestCase):
|
||||
assert (
|
||||
isinstance(data.content, str)
|
||||
or isinstance(data.reasoning_content, str)
|
||||
or len(data.tool_calls) > 0
|
||||
or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
|
||||
or response.choices[0].finish_reason
|
||||
)
|
||||
assert response.id
|
||||
@@ -282,152 +287,6 @@ class TestOpenAIServer(CustomTestCase):
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
def _create_batch(self, mode, client):
|
||||
if mode == "completion":
|
||||
input_file_path = "complete_input.jsonl"
|
||||
# write content to input file
|
||||
content = [
|
||||
{
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "List 3 names of famous soccer player: ",
|
||||
"max_tokens": 20,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "request-2",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "List 6 names of famous basketball player: ",
|
||||
"max_tokens": 40,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "request-3",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "List 6 names of famous tenniss player: ",
|
||||
"max_tokens": 40,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
else:
|
||||
input_file_path = "chat_input.jsonl"
|
||||
content = [
|
||||
{
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello! List 3 NBA players and tell a story",
|
||||
},
|
||||
],
|
||||
"max_tokens": 30,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "request-2",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are an assistant. "},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello! List three capital and tell a story",
|
||||
},
|
||||
],
|
||||
"max_tokens": 50,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with open(input_file_path, "w") as file:
|
||||
for line in content:
|
||||
file.write(json.dumps(line) + "\n")
|
||||
|
||||
with open(input_file_path, "rb") as file:
|
||||
uploaded_file = client.files.create(file=file, purpose="batch")
|
||||
if mode == "completion":
|
||||
endpoint = "/v1/completions"
|
||||
elif mode == "chat":
|
||||
endpoint = "/v1/chat/completions"
|
||||
completion_window = "24h"
|
||||
batch_job = client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint=endpoint,
|
||||
completion_window=completion_window,
|
||||
)
|
||||
|
||||
return batch_job, content, uploaded_file
|
||||
|
||||
def run_batch(self, mode):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client)
|
||||
|
||||
while batch_job.status not in ["completed", "failed", "cancelled"]:
|
||||
time.sleep(3)
|
||||
print(
|
||||
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
||||
)
|
||||
batch_job = client.batches.retrieve(batch_job.id)
|
||||
assert (
|
||||
batch_job.status == "completed"
|
||||
), f"Batch job status is not completed: {batch_job.status}"
|
||||
assert batch_job.request_counts.completed == len(content)
|
||||
assert batch_job.request_counts.failed == 0
|
||||
assert batch_job.request_counts.total == len(content)
|
||||
|
||||
result_file_id = batch_job.output_file_id
|
||||
file_response = client.files.content(result_file_id)
|
||||
result_content = file_response.read().decode("utf-8") # Decode bytes to string
|
||||
results = [
|
||||
json.loads(line)
|
||||
for line in result_content.split("\n")
|
||||
if line.strip() != ""
|
||||
]
|
||||
assert len(results) == len(content)
|
||||
for delete_fid in [uploaded_file.id, result_file_id]:
|
||||
del_pesponse = client.files.delete(delete_fid)
|
||||
assert del_pesponse.deleted
|
||||
|
||||
def run_cancel_batch(self, mode):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client)
|
||||
|
||||
assert batch_job.status not in ["cancelling", "cancelled"]
|
||||
|
||||
batch_job = client.batches.cancel(batch_id=batch_job.id)
|
||||
assert batch_job.status == "cancelling"
|
||||
|
||||
while batch_job.status not in ["failed", "cancelled"]:
|
||||
batch_job = client.batches.retrieve(batch_job.id)
|
||||
print(
|
||||
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
||||
)
|
||||
time.sleep(3)
|
||||
|
||||
assert batch_job.status == "cancelled"
|
||||
del_response = client.files.delete(uploaded_file.id)
|
||||
assert del_response.deleted
|
||||
|
||||
def test_completion(self):
|
||||
for echo in [False, True]:
|
||||
for logprobs in [None, 5]:
|
||||
@@ -467,14 +326,6 @@ class TestOpenAIServer(CustomTestCase):
|
||||
for parallel_sample_num in [1, 2]:
|
||||
self.run_chat_completion_stream(logprobs, parallel_sample_num)
|
||||
|
||||
def test_batch(self):
|
||||
for mode in ["completion", "chat"]:
|
||||
self.run_batch(mode)
|
||||
|
||||
def test_cancel_batch(self):
|
||||
for mode in ["completion", "chat"]:
|
||||
self.run_cancel_batch(mode)
|
||||
|
||||
def test_regex(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
@@ -559,6 +410,18 @@ The SmartHome Mini is a compact smart home assistant available in black or white
|
||||
assert len(models) == 1
|
||||
assert isinstance(getattr(models[0], "max_model_len", None), int)
|
||||
|
||||
def test_retrieve_model(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
# Test retrieving an existing model
|
||||
retrieved_model = client.models.retrieve(self.model)
|
||||
self.assertEqual(retrieved_model.id, self.model)
|
||||
self.assertEqual(retrieved_model.root, self.model)
|
||||
|
||||
# Test retrieving a non-existent model
|
||||
with self.assertRaises(openai.NotFoundError):
|
||||
client.models.retrieve("non-existent-model")
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# EBNF Test Class: TestOpenAIServerEBNF
|
||||
@@ -684,6 +547,31 @@ class TestOpenAIEmbedding(CustomTestCase):
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
self.assertTrue(len(response.data[1].embedding) > 0)
|
||||
|
||||
def test_embedding_single_batch_str(self):
|
||||
"""Test embedding with a List[str] and length equals to 1"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(model=self.model, input=["Hello world"])
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_embedding_single_int_list(self):
|
||||
"""Test embedding with a List[int] or List[List[int]]]"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model,
|
||||
input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
|
||||
)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model,
|
||||
input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
|
||||
)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_empty_string_embedding(self):
|
||||
"""Test embedding an empty string."""
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from transformers import (
|
||||
from sglang import Engine
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.conversation import generate_chat_conv
|
||||
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
@@ -31,7 +32,6 @@ from sglang.srt.managers.schedule_batch import (
|
||||
MultimodalInputs,
|
||||
)
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from transformers import (
|
||||
|
||||
from sglang import Engine
|
||||
from sglang.srt.conversation import generate_chat_conv
|
||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
|
||||
TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user