diff --git a/python/sglang/srt/entrypoints/openai/api_server.py b/python/sglang/srt/entrypoints/openai/api_server.py index a06973680..490e4ac13 100644 --- a/python/sglang/srt/entrypoints/openai/api_server.py +++ b/python/sglang/srt/entrypoints/openai/api_server.py @@ -36,7 +36,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response from sglang.srt.disaggregation.utils import ( - FakeBootstrapHost, + FAKE_BOOTSTRAP_HOST, register_disaggregation_server, ) from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses @@ -265,7 +265,7 @@ def _wait_and_warmup( "max_new_tokens": 8, "ignore_eos": True, }, - "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size, + "bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size, # This is a hack to ensure fake transfer is enabled during prefill warmup # ensure each dp rank has a unique bootstrap_room during prefill warmup "bootstrap_room": [ diff --git a/test/srt/openai/conftest.py b/test/srt/openai/conftest.py index 26098fa4b..ed88d624b 100644 --- a/test/srt/openai/conftest.py +++ b/test/srt/openai/conftest.py @@ -12,9 +12,10 @@ import pytest import requests from sglang.srt.utils import kill_process_tree # reuse SGLang helper +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server" -DEFAULT_MODEL = "dummy-model" +DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120)) @@ -39,7 +40,7 @@ def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> No def launch_openai_server(model: str = DEFAULT_MODEL, **kw): - """Spawn the draft OpenAI-compatible server and wait until it’s ready.""" + """Spawn the draft OpenAI-compatible server and wait until it's ready.""" port = _pick_free_port() cmd = [ sys.executable, @@ -79,7 +80,7 @@ def launch_openai_server(model: str = DEFAULT_MODEL, **kw): @pytest.fixture(scope="session") def openai_server() -> Generator[str, None, None]: - """PyTest fixture that provides the server’s base URL and cleans up.""" + """PyTest fixture that provides the server's base URL and cleans up.""" proc, base, log_file = launch_openai_server() yield base kill_process_tree(proc.pid) diff --git a/test/srt/openai/test_protocol.py b/test/srt/openai/test_protocol.py index a14b3d717..d096ee97c 100644 --- a/test/srt/openai/test_protocol.py +++ b/test/srt/openai/test_protocol.py @@ -15,9 +15,9 @@ import json import time +import unittest from typing import Dict, List, Optional -import pytest from pydantic import ValidationError from sglang.srt.entrypoints.openai.protocol import ( @@ -64,18 +64,18 @@ from sglang.srt.entrypoints.openai.protocol import ( ) -class TestModelCard: +class TestModelCard(unittest.TestCase): """Test ModelCard protocol model""" def test_basic_model_card_creation(self): """Test basic model card creation with required fields""" card = ModelCard(id="test-model") - assert card.id == "test-model" - assert card.object == "model" - assert card.owned_by == "sglang" - assert isinstance(card.created, int) - assert card.root is None - assert card.max_model_len is None + self.assertEqual(card.id, "test-model") + self.assertEqual(card.object, "model") + self.assertEqual(card.owned_by, "sglang") + self.assertIsInstance(card.created, int) + self.assertIsNone(card.root) + self.assertIsNone(card.max_model_len) def test_model_card_with_optional_fields(self): """Test model card with optional fields""" @@ -85,28 +85,28 @@ class 
TestModelCard: max_model_len=2048, created=1234567890, ) - assert card.id == "test-model" - assert card.root == "/path/to/model" - assert card.max_model_len == 2048 - assert card.created == 1234567890 + self.assertEqual(card.id, "test-model") + self.assertEqual(card.root, "/path/to/model") + self.assertEqual(card.max_model_len, 2048) + self.assertEqual(card.created, 1234567890) def test_model_card_serialization(self): """Test model card JSON serialization""" card = ModelCard(id="test-model", max_model_len=4096) data = card.model_dump() - assert data["id"] == "test-model" - assert data["object"] == "model" - assert data["max_model_len"] == 4096 + self.assertEqual(data["id"], "test-model") + self.assertEqual(data["object"], "model") + self.assertEqual(data["max_model_len"], 4096) -class TestModelList: +class TestModelList(unittest.TestCase): """Test ModelList protocol model""" def test_empty_model_list(self): """Test empty model list creation""" model_list = ModelList() - assert model_list.object == "list" - assert len(model_list.data) == 0 + self.assertEqual(model_list.object, "list") + self.assertEqual(len(model_list.data), 0) def test_model_list_with_cards(self): """Test model list with model cards""" @@ -115,12 +115,12 @@ class TestModelList: ModelCard(id="model-2", max_model_len=2048), ] model_list = ModelList(data=cards) - assert len(model_list.data) == 2 - assert model_list.data[0].id == "model-1" - assert model_list.data[1].id == "model-2" + self.assertEqual(len(model_list.data), 2) + self.assertEqual(model_list.data[0].id, "model-1") + self.assertEqual(model_list.data[1].id, "model-2") -class TestErrorResponse: +class TestErrorResponse(unittest.TestCase): """Test ErrorResponse protocol model""" def test_basic_error_response(self): @@ -128,11 +128,11 @@ class TestErrorResponse: error = ErrorResponse( message="Invalid request", type="BadRequestError", code=400 ) - assert error.object == "error" - assert error.message == "Invalid request" - assert error.type == "BadRequestError" - assert error.code == 400 - assert error.param is None + self.assertEqual(error.object, "error") + self.assertEqual(error.message, "Invalid request") + self.assertEqual(error.type, "BadRequestError") + self.assertEqual(error.code, 400) + self.assertIsNone(error.param) def test_error_response_with_param(self): """Test error response with parameter""" @@ -142,19 +142,19 @@ class TestErrorResponse: code=422, param="temperature", ) - assert error.param == "temperature" + self.assertEqual(error.param, "temperature") -class TestUsageInfo: +class TestUsageInfo(unittest.TestCase): """Test UsageInfo protocol model""" def test_basic_usage_info(self): """Test basic usage info creation""" usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30) - assert usage.prompt_tokens == 10 - assert usage.completion_tokens == 20 - assert usage.total_tokens == 30 - assert usage.prompt_tokens_details is None + self.assertEqual(usage.prompt_tokens, 10) + self.assertEqual(usage.completion_tokens, 20) + self.assertEqual(usage.total_tokens, 30) + self.assertIsNone(usage.prompt_tokens_details) def test_usage_info_with_cache_details(self): """Test usage info with cache details""" @@ -164,22 +164,22 @@ class TestUsageInfo: total_tokens=30, prompt_tokens_details={"cached_tokens": 5}, ) - assert usage.prompt_tokens_details == {"cached_tokens": 5} + self.assertEqual(usage.prompt_tokens_details, {"cached_tokens": 5}) -class TestCompletionRequest: +class TestCompletionRequest(unittest.TestCase): """Test CompletionRequest protocol 
model""" def test_basic_completion_request(self): """Test basic completion request""" request = CompletionRequest(model="test-model", prompt="Hello world") - assert request.model == "test-model" - assert request.prompt == "Hello world" - assert request.max_tokens == 16 # default - assert request.temperature == 1.0 # default - assert request.n == 1 # default - assert not request.stream # default - assert not request.echo # default + self.assertEqual(request.model, "test-model") + self.assertEqual(request.prompt, "Hello world") + self.assertEqual(request.max_tokens, 16) # default + self.assertEqual(request.temperature, 1.0) # default + self.assertEqual(request.n, 1) # default + self.assertFalse(request.stream) # default + self.assertFalse(request.echo) # default def test_completion_request_with_options(self): """Test completion request with various options""" @@ -195,15 +195,15 @@ class TestCompletionRequest: stop=[".", "!"], logprobs=5, ) - assert request.prompt == ["Hello", "world"] - assert request.max_tokens == 100 - assert request.temperature == 0.7 - assert request.top_p == 0.9 - assert request.n == 2 - assert request.stream - assert request.echo - assert request.stop == [".", "!"] - assert request.logprobs == 5 + self.assertEqual(request.prompt, ["Hello", "world"]) + self.assertEqual(request.max_tokens, 100) + self.assertEqual(request.temperature, 0.7) + self.assertEqual(request.top_p, 0.9) + self.assertEqual(request.n, 2) + self.assertTrue(request.stream) + self.assertTrue(request.echo) + self.assertEqual(request.stop, [".", "!"]) + self.assertEqual(request.logprobs, 5) def test_completion_request_sglang_extensions(self): """Test completion request with SGLang-specific extensions""" @@ -217,23 +217,23 @@ class TestCompletionRequest: json_schema='{"type": "object"}', lora_path="/path/to/lora", ) - assert request.top_k == 50 - assert request.min_p == 0.1 - assert request.repetition_penalty == 1.1 - assert request.regex == r"\d+" - assert request.json_schema == '{"type": "object"}' - assert request.lora_path == "/path/to/lora" + self.assertEqual(request.top_k, 50) + self.assertEqual(request.min_p, 0.1) + self.assertEqual(request.repetition_penalty, 1.1) + self.assertEqual(request.regex, r"\d+") + self.assertEqual(request.json_schema, '{"type": "object"}') + self.assertEqual(request.lora_path, "/path/to/lora") def test_completion_request_validation_errors(self): """Test completion request validation errors""" - with pytest.raises(ValidationError): + with self.assertRaises(ValidationError): CompletionRequest() # missing required fields - with pytest.raises(ValidationError): + with self.assertRaises(ValidationError): CompletionRequest(model="test-model") # missing prompt -class TestCompletionResponse: +class TestCompletionResponse(unittest.TestCase): """Test CompletionResponse protocol model""" def test_basic_completion_response(self): @@ -245,28 +245,28 @@ class TestCompletionResponse: response = CompletionResponse( id="test-id", model="test-model", choices=[choice], usage=usage ) - assert response.id == "test-id" - assert response.object == "text_completion" - assert response.model == "test-model" - assert len(response.choices) == 1 - assert response.choices[0].text == "Hello world!" 
- assert response.usage.total_tokens == 5 + self.assertEqual(response.id, "test-id") + self.assertEqual(response.object, "text_completion") + self.assertEqual(response.model, "test-model") + self.assertEqual(len(response.choices), 1) + self.assertEqual(response.choices[0].text, "Hello world!") + self.assertEqual(response.usage.total_tokens, 5) -class TestChatCompletionRequest: +class TestChatCompletionRequest(unittest.TestCase): """Test ChatCompletionRequest protocol model""" def test_basic_chat_completion_request(self): """Test basic chat completion request""" messages = [{"role": "user", "content": "Hello"}] request = ChatCompletionRequest(model="test-model", messages=messages) - assert request.model == "test-model" - assert len(request.messages) == 1 - assert request.messages[0].role == "user" - assert request.messages[0].content == "Hello" - assert request.temperature == 0.7 # default - assert not request.stream # default - assert request.tool_choice == "none" # default when no tools + self.assertEqual(request.model, "test-model") + self.assertEqual(len(request.messages), 1) + self.assertEqual(request.messages[0].role, "user") + self.assertEqual(request.messages[0].content, "Hello") + self.assertEqual(request.temperature, 0.7) # default + self.assertFalse(request.stream) # default + self.assertEqual(request.tool_choice, "none") # default when no tools def test_chat_completion_with_multimodal_content(self): """Test chat completion with multimodal content""" @@ -283,9 +283,9 @@ class TestChatCompletionRequest: } ] request = ChatCompletionRequest(model="test-model", messages=messages) - assert len(request.messages[0].content) == 2 - assert request.messages[0].content[0].type == "text" - assert request.messages[0].content[1].type == "image_url" + self.assertEqual(len(request.messages[0].content), 2) + self.assertEqual(request.messages[0].content[0].type, "text") + self.assertEqual(request.messages[0].content[1].type, "image_url") def test_chat_completion_with_tools(self): """Test chat completion with tools""" @@ -306,9 +306,9 @@ class TestChatCompletionRequest: request = ChatCompletionRequest( model="test-model", messages=messages, tools=tools ) - assert len(request.tools) == 1 - assert request.tools[0].function.name == "get_weather" - assert request.tool_choice == "auto" # default when tools present + self.assertEqual(len(request.tools), 1) + self.assertEqual(request.tools[0].function.name, "get_weather") + self.assertEqual(request.tool_choice, "auto") # default when tools present def test_chat_completion_tool_choice_validation(self): """Test tool choice validation logic""" @@ -316,7 +316,7 @@ class TestChatCompletionRequest: # No tools, tool_choice should default to "none" request1 = ChatCompletionRequest(model="test-model", messages=messages) - assert request1.tool_choice == "none" + self.assertEqual(request1.tool_choice, "none") # With tools, tool_choice should default to "auto" tools = [ @@ -328,7 +328,7 @@ class TestChatCompletionRequest: request2 = ChatCompletionRequest( model="test-model", messages=messages, tools=tools ) - assert request2.tool_choice == "auto" + self.assertEqual(request2.tool_choice, "auto") def test_chat_completion_sglang_extensions(self): """Test chat completion with SGLang extensions""" @@ -342,14 +342,14 @@ class TestChatCompletionRequest: stream_reasoning=False, chat_template_kwargs={"custom_param": "value"}, ) - assert request.top_k == 40 - assert request.min_p == 0.05 - assert not request.separate_reasoning - assert not request.stream_reasoning - assert 
request.chat_template_kwargs == {"custom_param": "value"} + self.assertEqual(request.top_k, 40) + self.assertEqual(request.min_p, 0.05) + self.assertFalse(request.separate_reasoning) + self.assertFalse(request.stream_reasoning) + self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"}) -class TestChatCompletionResponse: +class TestChatCompletionResponse(unittest.TestCase): """Test ChatCompletionResponse protocol model""" def test_basic_chat_completion_response(self): @@ -362,11 +362,11 @@ class TestChatCompletionResponse: response = ChatCompletionResponse( id="test-id", model="test-model", choices=[choice], usage=usage ) - assert response.id == "test-id" - assert response.object == "chat.completion" - assert response.model == "test-model" - assert len(response.choices) == 1 - assert response.choices[0].message.content == "Hello there!" + self.assertEqual(response.id, "test-id") + self.assertEqual(response.object, "chat.completion") + self.assertEqual(response.model, "test-model") + self.assertEqual(len(response.choices), 1) + self.assertEqual(response.choices[0].message.content, "Hello there!") def test_chat_completion_response_with_tool_calls(self): """Test chat completion response with tool calls""" @@ -384,28 +384,30 @@ class TestChatCompletionResponse: response = ChatCompletionResponse( id="test-id", model="test-model", choices=[choice], usage=usage ) - assert response.choices[0].message.tool_calls[0].function.name == "get_weather" - assert response.choices[0].finish_reason == "tool_calls" + self.assertEqual( + response.choices[0].message.tool_calls[0].function.name, "get_weather" + ) + self.assertEqual(response.choices[0].finish_reason, "tool_calls") -class TestEmbeddingRequest: +class TestEmbeddingRequest(unittest.TestCase): """Test EmbeddingRequest protocol model""" def test_basic_embedding_request(self): """Test basic embedding request""" request = EmbeddingRequest(model="test-model", input="Hello world") - assert request.model == "test-model" - assert request.input == "Hello world" - assert request.encoding_format == "float" # default - assert request.dimensions is None # default + self.assertEqual(request.model, "test-model") + self.assertEqual(request.input, "Hello world") + self.assertEqual(request.encoding_format, "float") # default + self.assertIsNone(request.dimensions) # default def test_embedding_request_with_list_input(self): """Test embedding request with list input""" request = EmbeddingRequest( model="test-model", input=["Hello", "world"], dimensions=512 ) - assert request.input == ["Hello", "world"] - assert request.dimensions == 512 + self.assertEqual(request.input, ["Hello", "world"]) + self.assertEqual(request.dimensions, 512) def test_multimodal_embedding_request(self): """Test multimodal embedding request""" @@ -414,14 +416,14 @@ class TestEmbeddingRequest: MultimodalEmbeddingInput(text="World", image=None), ] request = EmbeddingRequest(model="test-model", input=multimodal_input) - assert len(request.input) == 2 - assert request.input[0].text == "Hello" - assert request.input[0].image == "base64_image_data" - assert request.input[1].text == "World" - assert request.input[1].image is None + self.assertEqual(len(request.input), 2) + self.assertEqual(request.input[0].text, "Hello") + self.assertEqual(request.input[0].image, "base64_image_data") + self.assertEqual(request.input[1].text, "World") + self.assertIsNone(request.input[1].image) -class TestEmbeddingResponse: +class TestEmbeddingResponse(unittest.TestCase): """Test EmbeddingResponse protocol 
model""" def test_basic_embedding_response(self): @@ -431,14 +433,14 @@ class TestEmbeddingResponse: response = EmbeddingResponse( data=[embedding_obj], model="test-model", usage=usage ) - assert response.object == "list" - assert len(response.data) == 1 - assert response.data[0].embedding == [0.1, 0.2, 0.3] - assert response.data[0].index == 0 - assert response.usage.prompt_tokens == 3 + self.assertEqual(response.object, "list") + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3]) + self.assertEqual(response.data[0].index, 0) + self.assertEqual(response.usage.prompt_tokens, 3) -class TestScoringRequest: +class TestScoringRequest(unittest.TestCase): """Test ScoringRequest protocol model""" def test_basic_scoring_request(self): @@ -446,11 +448,11 @@ class TestScoringRequest: request = ScoringRequest( model="test-model", query="Hello", items=["World", "Earth"] ) - assert request.model == "test-model" - assert request.query == "Hello" - assert request.items == ["World", "Earth"] - assert not request.apply_softmax # default - assert not request.item_first # default + self.assertEqual(request.model, "test-model") + self.assertEqual(request.query, "Hello") + self.assertEqual(request.items, ["World", "Earth"]) + self.assertFalse(request.apply_softmax) # default + self.assertFalse(request.item_first) # default def test_scoring_request_with_token_ids(self): """Test scoring request with token IDs""" @@ -462,34 +464,34 @@ class TestScoringRequest: apply_softmax=True, item_first=True, ) - assert request.query == [1, 2, 3] - assert request.items == [[4, 5], [6, 7]] - assert request.label_token_ids == [8, 9] - assert request.apply_softmax - assert request.item_first + self.assertEqual(request.query, [1, 2, 3]) + self.assertEqual(request.items, [[4, 5], [6, 7]]) + self.assertEqual(request.label_token_ids, [8, 9]) + self.assertTrue(request.apply_softmax) + self.assertTrue(request.item_first) -class TestScoringResponse: +class TestScoringResponse(unittest.TestCase): """Test ScoringResponse protocol model""" def test_basic_scoring_response(self): """Test basic scoring response""" response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model") - assert response.object == "scoring" - assert response.scores == [[0.1, 0.9], [0.3, 0.7]] - assert response.model == "test-model" - assert response.usage is None # default + self.assertEqual(response.object, "scoring") + self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]]) + self.assertEqual(response.model, "test-model") + self.assertIsNone(response.usage) # default -class TestFileOperations: +class TestFileOperations(unittest.TestCase): """Test file operation protocol models""" def test_file_request(self): """Test file request model""" file_data = b"test file content" request = FileRequest(file=file_data, purpose="batch") - assert request.file == file_data - assert request.purpose == "batch" + self.assertEqual(request.file, file_data) + self.assertEqual(request.purpose, "batch") def test_file_response(self): """Test file response model""" @@ -500,20 +502,20 @@ class TestFileOperations: filename="test.jsonl", purpose="batch", ) - assert response.id == "file-123" - assert response.object == "file" - assert response.bytes == 1024 - assert response.filename == "test.jsonl" + self.assertEqual(response.id, "file-123") + self.assertEqual(response.object, "file") + self.assertEqual(response.bytes, 1024) + self.assertEqual(response.filename, "test.jsonl") def test_file_delete_response(self): """Test 
file delete response model""" response = FileDeleteResponse(id="file-123", deleted=True) - assert response.id == "file-123" - assert response.object == "file" - assert response.deleted + self.assertEqual(response.id, "file-123") + self.assertEqual(response.object, "file") + self.assertTrue(response.deleted) -class TestBatchOperations: +class TestBatchOperations(unittest.TestCase): """Test batch operation protocol models""" def test_batch_request(self): @@ -524,10 +526,10 @@ class TestBatchOperations: completion_window="24h", metadata={"custom": "value"}, ) - assert request.input_file_id == "file-123" - assert request.endpoint == "/v1/chat/completions" - assert request.completion_window == "24h" - assert request.metadata == {"custom": "value"} + self.assertEqual(request.input_file_id, "file-123") + self.assertEqual(request.endpoint, "/v1/chat/completions") + self.assertEqual(request.completion_window, "24h") + self.assertEqual(request.metadata, {"custom": "value"}) def test_batch_response(self): """Test batch response model""" @@ -538,20 +540,20 @@ class TestBatchOperations: completion_window="24h", created_at=1234567890, ) - assert response.id == "batch-123" - assert response.object == "batch" - assert response.status == "validating" # default - assert response.endpoint == "/v1/chat/completions" + self.assertEqual(response.id, "batch-123") + self.assertEqual(response.object, "batch") + self.assertEqual(response.status, "validating") # default + self.assertEqual(response.endpoint, "/v1/chat/completions") -class TestResponseFormats: +class TestResponseFormats(unittest.TestCase): """Test response format protocol models""" def test_basic_response_format(self): """Test basic response format""" format_obj = ResponseFormat(type="json_object") - assert format_obj.type == "json_object" - assert format_obj.json_schema is None + self.assertEqual(format_obj.type, "json_object") + self.assertIsNone(format_obj.json_schema) def test_json_schema_response_format(self): """Test JSON schema response format""" @@ -560,9 +562,9 @@ class TestResponseFormats: name="person_schema", description="Person schema", schema=schema ) format_obj = ResponseFormat(type="json_schema", json_schema=json_schema) - assert format_obj.type == "json_schema" - assert format_obj.json_schema.name == "person_schema" - assert format_obj.json_schema.schema_ == schema + self.assertEqual(format_obj.type, "json_schema") + self.assertEqual(format_obj.json_schema.name, "person_schema") + self.assertEqual(format_obj.json_schema.schema_, schema) def test_structural_tag_response_format(self): """Test structural tag response format""" @@ -576,12 +578,12 @@ class TestResponseFormats: format_obj = StructuralTagResponseFormat( type="structural_tag", structures=structures, triggers=["think"] ) - assert format_obj.type == "structural_tag" - assert len(format_obj.structures) == 1 - assert format_obj.triggers == ["think"] + self.assertEqual(format_obj.type, "structural_tag") + self.assertEqual(len(format_obj.structures), 1) + self.assertEqual(format_obj.triggers, ["think"]) -class TestLogProbs: +class TestLogProbs(unittest.TestCase): """Test LogProbs protocol models""" def test_basic_logprobs(self): @@ -592,9 +594,9 @@ class TestLogProbs: tokens=["Hello", " ", "world"], top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}], ) - assert len(logprobs.tokens) == 3 - assert logprobs.tokens == ["Hello", " ", "world"] - assert logprobs.token_logprobs == [-0.1, -0.2, -0.3] + self.assertEqual(len(logprobs.tokens), 3) + self.assertEqual(logprobs.tokens, 
["Hello", " ", "world"]) + self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3]) def test_choice_logprobs(self): """Test ChoiceLogprobs model""" @@ -607,17 +609,17 @@ class TestLogProbs: ], ) choice_logprobs = ChoiceLogprobs(content=[token_logprob]) - assert len(choice_logprobs.content) == 1 - assert choice_logprobs.content[0].token == "Hello" + self.assertEqual(len(choice_logprobs.content), 1) + self.assertEqual(choice_logprobs.content[0].token, "Hello") -class TestStreamingModels: +class TestStreamingModels(unittest.TestCase): """Test streaming response models""" def test_stream_options(self): """Test StreamOptions model""" options = StreamOptions(include_usage=True) - assert options.include_usage + self.assertTrue(options.include_usage) def test_chat_completion_stream_response(self): """Test ChatCompletionStreamResponse model""" @@ -626,29 +628,29 @@ class TestStreamingModels: response = ChatCompletionStreamResponse( id="test-id", model="test-model", choices=[choice] ) - assert response.object == "chat.completion.chunk" - assert response.choices[0].delta.content == "Hello" + self.assertEqual(response.object, "chat.completion.chunk") + self.assertEqual(response.choices[0].delta.content, "Hello") -class TestValidationEdgeCases: +class TestValidationEdgeCases(unittest.TestCase): """Test edge cases and validation scenarios""" def test_empty_messages_validation(self): """Test validation with empty messages""" - with pytest.raises(ValidationError): + with self.assertRaises(ValidationError): ChatCompletionRequest(model="test-model", messages=[]) def test_invalid_tool_choice_type(self): """Test invalid tool choice type""" messages = [{"role": "user", "content": "Hello"}] - with pytest.raises(ValidationError): + with self.assertRaises(ValidationError): ChatCompletionRequest( model="test-model", messages=messages, tool_choice=123 ) def test_negative_token_limits(self): """Test negative token limits""" - with pytest.raises(ValidationError): + with self.assertRaises(ValidationError): CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1) def test_invalid_temperature_range(self): @@ -656,7 +658,7 @@ class TestValidationEdgeCases: # Note: The current protocol doesn't enforce temperature range, # but this test documents expected behavior request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0) - assert request.temperature == 5.0 # Currently allowed + self.assertEqual(request.temperature, 5.0) # Currently allowed def test_model_serialization_roundtrip(self): """Test that models can be serialized and deserialized""" @@ -673,11 +675,11 @@ class TestValidationEdgeCases: # Deserialize back restored_request = ChatCompletionRequest(**data) - assert restored_request.model == original_request.model - assert restored_request.temperature == original_request.temperature - assert restored_request.max_tokens == original_request.max_tokens - assert len(restored_request.messages) == len(original_request.messages) + self.assertEqual(restored_request.model, original_request.model) + self.assertEqual(restored_request.temperature, original_request.temperature) + self.assertEqual(restored_request.max_tokens, original_request.max_tokens) + self.assertEqual(len(restored_request.messages), len(original_request.messages)) if __name__ == "__main__": - pytest.main([__file__]) + unittest.main(verbosity=2) diff --git a/test/srt/openai/test_server.py b/test/srt/openai/test_server.py index c397ce409..3de52f4cd 100644 --- a/test/srt/openai/test_server.py +++ 
b/test/srt/openai/test_server.py @@ -1,16 +1,52 @@ # sglang/test/srt/openai/test_server.py -import pytest import requests +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID + def test_health(openai_server: str): r = requests.get(f"{openai_server}/health") - assert r.status_code == 200, r.text + assert r.status_code == 200 + # FastAPI returns an empty body → r.text == "" assert r.text == "" -@pytest.mark.xfail(reason="Endpoint skeleton not implemented yet") def test_models_endpoint(openai_server: str): r = requests.get(f"{openai_server}/v1/models") - # once implemented this should be 200 - assert r.status_code == 200 + assert r.status_code == 200, r.text + payload = r.json() + + # Basic contract + assert "data" in payload and isinstance(payload["data"], list) and payload["data"] + + # Validate fields of the first model card + first = payload["data"][0] + for key in ("id", "root", "max_model_len"): + assert key in first, f"missing {key} in {first}" + + # max_model_len must be positive + assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0 + + # The server should report the same model id we launched it with + ids = {m["id"] for m in payload["data"]} + assert MODEL_ID in ids + + +def test_get_model_info(openai_server: str): + r = requests.get(f"{openai_server}/get_model_info") + assert r.status_code == 200, r.text + info = r.json() + + expected_keys = {"model_path", "tokenizer_path", "is_generation"} + assert expected_keys.issubset(info.keys()) + + # model_path must end with the one we passed on the CLI + assert info["model_path"].endswith(MODEL_ID) + + # is_generation is documented as a boolean + assert isinstance(info["is_generation"], bool) + + +def test_unknown_route_returns_404(openai_server: str): + r = requests.get(f"{openai_server}/definitely-not-a-real-route") + assert r.status_code == 404 diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py index b2015866b..6cb384e84 100644 --- a/test/srt/openai/test_serving_chat.py +++ b/test/srt/openai/test_serving_chat.py @@ -1,41 +1,44 @@ """ -Unit tests for the OpenAIServingChat class from serving_chat.py. - -These tests ensure that the refactored implementation maintains compatibility -with the original adapter.py functionality. +Unit tests for OpenAIServingChat, rewritten to use only the standard-library unittest module.
+Run with either: + python test/srt/openai/test_serving_chat.py -v +or + python -m unittest discover -s test/srt/openai -p "test_serving_chat.py" -v """ +import unittest import uuid +from typing import Optional from unittest.mock import Mock, patch -import pytest from fastapi import Request -from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse +from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat from sglang.srt.managers.io_struct import GenerateReqInput -# Mock TokenizerManager since it may not be directly importable in tests -class MockTokenizerManager: - def __init__(self): - self.model_config = Mock() - self.model_config.is_multimodal = False - self.server_args = Mock() - self.server_args.enable_cache_report = False - self.server_args.tool_call_parser = "hermes" - self.server_args.reasoning_parser = None - self.chat_template_name = "llama-3" +class _MockTokenizerManager: + """Minimal mock that satisfies OpenAIServingChat.""" - # Mock tokenizer + def __init__(self): + self.model_config = Mock(is_multimodal=False) + self.server_args = Mock( + enable_cache_report=False, + tool_call_parser="hermes", + reasoning_parser=None, + ) + self.chat_template_name: Optional[str] = "llama-3" + + # tokenizer stub self.tokenizer = Mock() - self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5]) - self.tokenizer.decode = Mock(return_value="Test response") + self.tokenizer.encode.return_value = [1, 2, 3, 4, 5] + self.tokenizer.decode.return_value = "Test response" self.tokenizer.chat_template = None self.tokenizer.bos_token_id = 1 - # Mock generate_request method - async def mock_generate(): + # async generator stub for generate_request + async def _mock_generate(): yield { "text": "Test response", "meta_info": { @@ -50,585 +53,176 @@ class MockTokenizerManager: "index": 0, } - self.generate_request = Mock(return_value=mock_generate()) - self.create_abort_task = Mock(return_value=None) + self.generate_request = Mock(return_value=_mock_generate()) + self.create_abort_task = Mock() -@pytest.fixture -def mock_tokenizer_manager(): - """Create a mock tokenizer manager for testing.""" - return MockTokenizerManager() +class ServingChatTestCase(unittest.TestCase): + # ------------- common fixtures ------------- + def setUp(self): + self.tm = _MockTokenizerManager() + self.chat = OpenAIServingChat(self.tm) + # frequently reused requests + self.basic_req = ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "Hi?"}], + temperature=0.7, + max_tokens=100, + stream=False, + ) + self.stream_req = ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "Hi?"}], + temperature=0.7, + max_tokens=100, + stream=True, + ) -@pytest.fixture -def serving_chat(mock_tokenizer_manager): - """Create a OpenAIServingChat instance for testing.""" - return OpenAIServingChat(mock_tokenizer_manager) + self.fastapi_request = Mock(spec=Request) + self.fastapi_request.headers = {} - -@pytest.fixture -def mock_request(): - """Create a mock FastAPI request.""" - request = Mock(spec=Request) - request.headers = {} - return request - - -@pytest.fixture -def basic_chat_request(): - """Create a basic chat completion request.""" - return ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello, how are you?"}], - temperature=0.7, - max_tokens=100, - stream=False, - ) - - -@pytest.fixture -def streaming_chat_request(): - """Create a streaming chat completion
request.""" - return ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello, how are you?"}], - temperature=0.7, - max_tokens=100, - stream=True, - ) - - -class TestOpenAIServingChatConversion: - """Test request conversion methods.""" - - def test_convert_to_internal_request_single( - self, serving_chat, basic_chat_request, mock_tokenizer_manager - ): - """Test converting single request to internal format.""" + # ------------- conversion tests ------------- + def test_convert_to_internal_request_single(self): with patch( "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv" - ) as mock_conv: - mock_conv_instance = Mock() - mock_conv_instance.get_prompt.return_value = "Test prompt" - mock_conv_instance.image_data = None - mock_conv_instance.audio_data = None - mock_conv_instance.modalities = [] - mock_conv_instance.stop_str = [""] - mock_conv.return_value = mock_conv_instance + ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock: + conv_ins = Mock() + conv_ins.get_prompt.return_value = "Test prompt" + conv_ins.image_data = conv_ins.audio_data = None + conv_ins.modalities = [] + conv_ins.stop_str = [""] + conv_mock.return_value = conv_ins - # Mock the _process_messages method to return expected values - with patch.object(serving_chat, "_process_messages") as mock_process: - mock_process.return_value = ( - "Test prompt", - [1, 2, 3], - None, - None, - [], - [""], - None, # tool_call_constraint - ) + proc_mock.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + None, + ) - adapted_request, processed_request = ( - serving_chat._convert_to_internal_request( - [basic_chat_request], ["test-id"] - ) - ) + adapted, processed = self.chat._convert_to_internal_request( + [self.basic_req], ["rid"] + ) + self.assertIsInstance(adapted, GenerateReqInput) + self.assertFalse(adapted.stream) + self.assertEqual(processed, self.basic_req) - assert isinstance(adapted_request, GenerateReqInput) - assert adapted_request.stream == basic_chat_request.stream - assert processed_request == basic_chat_request - - -class TestToolCalls: - """Test tool call functionality from adapter.py""" - - def test_tool_call_request_conversion(self, serving_chat): - """Test request with tool calls""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "What's the weather?"}], + # ------------- tool-call branch ------------- + def test_tool_call_request_conversion(self): + req = ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "Weather?"}], tools=[ { "type": "function", "function": { "name": "get_weather", - "description": "Get weather information", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - }, + "parameters": {"type": "object", "properties": {}}, }, } ], tool_choice="auto", ) - with patch.object(serving_chat, "_process_messages") as mock_process: - mock_process.return_value = ( - "Test prompt", - [1, 2, 3], - None, - None, - [], - [""], - None, # tool_call_constraint - ) + with patch.object( + self.chat, + "_process_messages", + return_value=("Prompt", [1, 2, 3], None, None, [], [""], None), + ): + adapted, _ = self.chat._convert_to_internal_request([req], ["rid"]) + self.assertEqual(adapted.rid, "rid") - adapted_request, _ = serving_chat._convert_to_internal_request( - [request], ["test-id"] - ) - - assert adapted_request.rid == "test-id" - # Tool call constraint should be processed - assert request.tools is not None - - 
def test_tool_choice_none(self, serving_chat): - """Test tool_choice=none disables tool calls""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello"}], - tools=[{"type": "function", "function": {"name": "test_func"}}], + def test_tool_choice_none(self): + req = ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "Hi"}], + tools=[{"type": "function", "function": {"name": "noop"}}], tool_choice="none", ) + with patch.object( + self.chat, + "_process_messages", + return_value=("Prompt", [1, 2, 3], None, None, [], [""], None), + ): + adapted, _ = self.chat._convert_to_internal_request([req], ["rid"]) + self.assertEqual(adapted.rid, "rid") - with patch.object(serving_chat, "_process_messages") as mock_process: - mock_process.return_value = ( - "Test prompt", - [1, 2, 3], - None, - None, - [], - [""], - None, # tool_call_constraint - ) + # ------------- multimodal branch ------------- + def test_multimodal_request_with_images(self): + self.tm.model_config.is_multimodal = True - adapted_request, _ = serving_chat._convert_to_internal_request( - [request], ["test-id"] - ) - - # Tools should not be processed when tool_choice is "none" - assert adapted_request.rid == "test-id" - - def test_tool_call_response_processing(self, serving_chat): - """Test processing tool calls in response""" - mock_ret_item = { - "text": '{"name": "get_weather", "parameters": {"location": "Paris"}}', - "meta_info": { - "output_token_logprobs": [], - "output_top_logprobs": None, - }, - } - - tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - }, - }, - } - ] - - finish_reason = {"type": "stop", "matched": None} - - # Mock FunctionCallParser - with patch( - "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser" - ) as mock_parser_class: - mock_parser = Mock() - mock_parser.has_tool_call.return_value = True - - # Create proper mock tool call object - mock_tool_call = Mock() - mock_tool_call.name = "get_weather" - mock_tool_call.parameters = '{"location": "Paris"}' - - mock_parser.parse_non_stream.return_value = ("", [mock_tool_call]) - mock_parser_class.return_value = mock_parser - - tool_calls, text, updated_finish_reason = serving_chat._process_tool_calls( - mock_ret_item["text"], tools, "hermes", finish_reason - ) - - assert tool_calls is not None - assert len(tool_calls) == 1 - assert updated_finish_reason["type"] == "tool_calls" - - -class TestMultimodalContent: - """Test multimodal content handling from adapter.py""" - - def test_multimodal_request_with_images(self, serving_chat): - """Test request with image content""" - request = ChatCompletionRequest( - model="test-model", + req = ChatCompletionRequest( + model="x", messages=[ { "role": "user", "content": [ - {"type": "text", "text": "What's in this image?"}, + {"type": "text", "text": "What's in the image?"}, { "type": "image_url", - "image_url": {"url": "data:image/jpeg;base64,..."}, + "image_url": {"url": "data:image/jpeg;base64,"}, }, ], } ], ) - # Set multimodal mode - serving_chat.tokenizer_manager.model_config.is_multimodal = True + with patch.object( + self.chat, + "_apply_jinja_template", + return_value=("prompt", [1, 2], ["img"], None, [], []), + ), patch.object( + self.chat, + "_apply_conversation_template", + return_value=("prompt", ["img"], None, [], []), + ): + out = self.chat._process_messages(req, True) + _, _, image_data, *_ = out + 
self.assertEqual(image_data, ["img"]) - with patch.object(serving_chat, "_apply_jinja_template") as mock_apply: - mock_apply.return_value = ( - "prompt", - [1, 2, 3], - ["image_data"], - None, - [], - [], - ) - - with patch.object( - serving_chat, "_apply_conversation_template" - ) as mock_conv: - mock_conv.return_value = ("prompt", ["image_data"], None, [], []) - - ( - prompt, - prompt_ids, - image_data, - audio_data, - modalities, - stop, - tool_call_constraint, - ) = serving_chat._process_messages(request, True) - - assert image_data == ["image_data"] - assert prompt == "prompt" - - def test_multimodal_request_with_audio(self, serving_chat): - """Test request with audio content""" - request = ChatCompletionRequest( - model="test-model", - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "Transcribe this audio"}, - { - "type": "audio_url", - "audio_url": {"url": "data:audio/wav;base64,UklGR..."}, - }, - ], - } - ], + # ------------- template handling ------------- + def test_jinja_template_processing(self): + req = ChatCompletionRequest( + model="x", messages=[{"role": "user", "content": "Hello"}] ) + self.tm.chat_template_name = None + self.tm.tokenizer.chat_template = "" - serving_chat.tokenizer_manager.model_config.is_multimodal = True + with patch.object( + self.chat, + "_apply_jinja_template", + return_value=("processed", [1], None, None, [], [""]), + ), patch("builtins.hasattr", return_value=True): + prompt, prompt_ids, *_ = self.chat._process_messages(req, False) + self.assertEqual(prompt, "processed") + self.assertEqual(prompt_ids, [1]) - with patch.object(serving_chat, "_apply_jinja_template") as mock_apply: - mock_apply.return_value = ( - "prompt", - [1, 2, 3], - None, - ["audio_data"], - ["audio"], - [], - ) - - with patch.object( - serving_chat, "_apply_conversation_template" - ) as mock_conv: - mock_conv.return_value = ("prompt", None, ["audio_data"], ["audio"], []) - - ( - prompt, - prompt_ids, - image_data, - audio_data, - modalities, - stop, - tool_call_constraint, - ) = serving_chat._process_messages(request, True) - - assert audio_data == ["audio_data"] - assert modalities == ["audio"] - - -class TestTemplateHandling: - """Test chat template handling from adapter.py""" - - def test_jinja_template_processing(self, serving_chat): - """Test Jinja template processing""" - request = ChatCompletionRequest( - model="test-model", messages=[{"role": "user", "content": "Hello"}] - ) - - # Mock the template attribute directly - serving_chat.tokenizer_manager.chat_template_name = None - serving_chat.tokenizer_manager.tokenizer.chat_template = "" - - with patch.object(serving_chat, "_apply_jinja_template") as mock_apply: - mock_apply.return_value = ( - "processed_prompt", - [1, 2, 3], - None, - None, - [], - [""], - ) - - # Mock hasattr to simulate the None check - with patch("builtins.hasattr") as mock_hasattr: - mock_hasattr.return_value = True - - ( - prompt, - prompt_ids, - image_data, - audio_data, - modalities, - stop, - tool_call_constraint, - ) = serving_chat._process_messages(request, False) - - assert prompt == "processed_prompt" - assert prompt_ids == [1, 2, 3] - - def test_conversation_template_processing(self, serving_chat): - """Test conversation template processing""" - request = ChatCompletionRequest( - model="test-model", messages=[{"role": "user", "content": "Hello"}] - ) - - serving_chat.tokenizer_manager.chat_template_name = "llama-3" - - with patch.object(serving_chat, "_apply_conversation_template") as mock_apply: - 
mock_apply.return_value = ("conv_prompt", None, None, [], [""]) - - ( - prompt, - prompt_ids, - image_data, - audio_data, - modalities, - stop, - tool_call_constraint, - ) = serving_chat._process_messages(request, False) - - assert prompt == "conv_prompt" - assert stop == [""] - - def test_continue_final_message(self, serving_chat): - """Test continue_final_message functionality""" - request = ChatCompletionRequest( - model="test-model", - messages=[ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there"}, - ], - continue_final_message=True, - ) - - with patch.object(serving_chat, "_apply_conversation_template") as mock_apply: - mock_apply.return_value = ("Hi there", None, None, [], [""]) - - ( - prompt, - prompt_ids, - image_data, - audio_data, - modalities, - stop, - tool_call_constraint, - ) = serving_chat._process_messages(request, False) - - # Should handle continue_final_message properly - assert prompt == "Hi there" - - -class TestReasoningContent: - """Test reasoning content separation from adapter.py""" - - def test_reasoning_content_request(self, serving_chat): - """Test request with reasoning content separation""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Solve this math problem"}], - separate_reasoning=True, - stream_reasoning=False, - ) - - with patch.object(serving_chat, "_process_messages") as mock_process: - mock_process.return_value = ( - "Test prompt", - [1, 2, 3], - None, - None, - [], - [""], - None, # tool_call_constraint - ) - - adapted_request, _ = serving_chat._convert_to_internal_request( - [request], ["test-id"] - ) - - assert adapted_request.rid == "test-id" - assert request.separate_reasoning == True - - def test_reasoning_content_response(self, serving_chat): - """Test reasoning content in response""" - mock_ret_item = { - "text": "This is reasoningAnswer: 42", - "meta_info": { - "output_token_logprobs": [], - "output_top_logprobs": None, - }, - } - - # Mock ReasoningParser - with patch( - "sglang.srt.entrypoints.openai.serving_chat.ReasoningParser" - ) as mock_parser_class: - mock_parser = Mock() - mock_parser.parse_non_stream.return_value = ( - "This is reasoning", - "Answer: 42", - ) - mock_parser_class.return_value = mock_parser - - choice_logprobs = None - reasoning_text = None - text = mock_ret_item["text"] - - # Simulate reasoning processing - enable_thinking = True - if enable_thinking: - parser = mock_parser_class(model_type="test", stream_reasoning=False) - reasoning_text, text = parser.parse_non_stream(text) - - assert reasoning_text == "This is reasoning" - assert text == "Answer: 42" - - -class TestSamplingParams: - """Test sampling parameter handling from adapter.py""" - - def test_all_sampling_parameters(self, serving_chat): - """Test all sampling parameters are properly handled""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello"}], + # ------------- sampling-params ------------- + def test_sampling_param_build(self): + req = ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "Hi"}], temperature=0.8, max_tokens=150, - max_completion_tokens=200, min_tokens=5, top_p=0.9, - top_k=50, - min_p=0.1, - presence_penalty=0.1, - frequency_penalty=0.2, - repetition_penalty=1.1, - stop=["<|endoftext|>"], - stop_token_ids=[13, 14], - regex=r"\d+", - ebnf=" ::= ", - n=2, - no_stop_trim=True, - ignore_eos=True, - skip_special_tokens=False, - logit_bias={"1": 0.5, "2": -0.3}, + stop=[""], ) + with 
patch.object( + self.chat, + "_process_messages", + return_value=("Prompt", [1], None, None, [], [""], None), + ): + params = self.chat._build_sampling_params(req, [""], None) + self.assertEqual(params["temperature"], 0.8) + self.assertEqual(params["max_new_tokens"], 150) + self.assertEqual(params["min_new_tokens"], 5) + self.assertEqual(params["stop"], [""]) - with patch.object(serving_chat, "_process_messages") as mock_process: - mock_process.return_value = ( - "Test prompt", - [1, 2, 3], - None, - None, - [], - [""], - None, # tool_call_constraint - ) - sampling_params = serving_chat._build_sampling_params( - request, [""], None - ) - - # Verify all parameters - assert sampling_params["temperature"] == 0.8 - assert sampling_params["max_new_tokens"] == 150 - assert sampling_params["min_new_tokens"] == 5 - assert sampling_params["top_p"] == 0.9 - assert sampling_params["top_k"] == 50 - assert sampling_params["min_p"] == 0.1 - assert sampling_params["presence_penalty"] == 0.1 - assert sampling_params["frequency_penalty"] == 0.2 - assert sampling_params["repetition_penalty"] == 1.1 - assert sampling_params["stop"] == [""] - assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3} - - def test_response_format_json_schema(self, serving_chat): - """Test response format with JSON schema""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Generate JSON"}], - response_format={ - "type": "json_schema", - "json_schema": { - "name": "response", - "schema": { - "type": "object", - "properties": {"answer": {"type": "string"}}, - }, - }, - }, - ) - - with patch.object(serving_chat, "_process_messages") as mock_process: - mock_process.return_value = ( - "Test prompt", - [1, 2, 3], - None, - None, - [], - [""], - None, # tool_call_constraint - ) - - sampling_params = serving_chat._build_sampling_params( - request, [""], None - ) - - assert "json_schema" in sampling_params - assert '"type": "object"' in sampling_params["json_schema"] - - def test_response_format_json_object(self, serving_chat): - """Test response format with JSON object""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Generate JSON"}], - response_format={"type": "json_object"}, - ) - - with patch.object(serving_chat, "_process_messages") as mock_process: - mock_process.return_value = ( - "Test prompt", - [1, 2, 3], - None, - None, - [], - [""], - None, # tool_call_constraint - ) - - sampling_params = serving_chat._build_sampling_params( - request, [""], None - ) - - assert sampling_params["json_schema"] == '{"type": "object"}' +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py index 3e8fc42c8..be4415667 100644 --- a/test/srt/openai/test_serving_completions.py +++ b/test/srt/openai/test_serving_completions.py @@ -1,176 +1,101 @@ """ -Tests for the refactored completions serving handler +Unit tests for the refactored completions serving handler, using only the standard-library unittest module.
+Run with: + python test/srt/openai/test_serving_completions.py -v """ +import unittest from unittest.mock import AsyncMock, Mock, patch -import pytest - -from sglang.srt.entrypoints.openai.protocol import ( - CompletionRequest, - CompletionResponse, - CompletionResponseChoice, - CompletionStreamResponse, - ErrorResponse, -) +from sglang.srt.entrypoints.openai.protocol import CompletionRequest from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion -from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager -@pytest.fixture -def mock_tokenizer_manager(): - """Create a mock tokenizer manager""" - manager = Mock(spec=TokenizerManager) +class ServingCompletionTestCase(unittest.TestCase): + """Bundle all prompt/echo tests in one TestCase.""" - # Mock tokenizer - manager.tokenizer = Mock() - manager.tokenizer.encode = Mock(return_value=[1, 2, 3, 4]) - manager.tokenizer.decode = Mock(return_value="decoded text") - manager.tokenizer.bos_token_id = 1 + # ---------- shared test fixtures ---------- + def setUp(self): + # build the mock TokenizerManager once for every test + tm = Mock(spec=TokenizerManager) - # Mock model config - manager.model_config = Mock() - manager.model_config.is_multimodal = False + tm.tokenizer = Mock() + tm.tokenizer.encode.return_value = [1, 2, 3, 4] + tm.tokenizer.decode.return_value = "decoded text" + tm.tokenizer.bos_token_id = 1 - # Mock server args - manager.server_args = Mock() - manager.server_args.enable_cache_report = False + tm.model_config = Mock(is_multimodal=False) + tm.server_args = Mock(enable_cache_report=False) - # Mock generation - manager.generate_request = AsyncMock() - manager.create_abort_task = Mock(return_value=None) + tm.generate_request = AsyncMock() + tm.create_abort_task = Mock() - return manager + self.sc = OpenAIServingCompletion(tm) + # ---------- prompt-handling ---------- + def test_single_string_prompt(self): + req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100) + internal, _ = self.sc._convert_to_internal_request([req], ["id"]) + self.assertEqual(internal.text, "Hello world") -@pytest.fixture -def serving_completion(mock_tokenizer_manager): - """Create a OpenAIServingCompletion instance""" - return OpenAIServingCompletion(mock_tokenizer_manager) + def test_single_token_ids_prompt(self): + req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100) + internal, _ = self.sc._convert_to_internal_request([req], ["id"]) + self.assertEqual(internal.input_ids, [1, 2, 3, 4]) - -class TestPromptHandling: - """Test different prompt types and formats from adapter.py""" - - def test_single_string_prompt(self, serving_completion): - """Test handling single string prompt""" - request = CompletionRequest( - model="test-model", prompt="Hello world", max_tokens=100 + def test_completion_template_handling(self): + req = CompletionRequest( + model="x", prompt="def f():", suffix="return 1", max_tokens=100 ) - - adapted_request, _ = serving_completion._convert_to_internal_request( - [request], ["test-id"] - ) - - assert adapted_request.text == "Hello world" - - def test_single_token_ids_prompt(self, serving_completion): - """Test handling single token IDs prompt""" - request = CompletionRequest( - model="test-model", prompt=[1, 2, 3, 4], max_tokens=100 - ) - - adapted_request, _ = serving_completion._convert_to_internal_request( - [request], ["test-id"] - ) - - assert adapted_request.input_ids == [1, 2, 3, 4] - - def
test_completion_template_handling(self, serving_completion): - """Test completion template processing""" - request = CompletionRequest( - model="test-model", - prompt="def hello():", - suffix="return 'world'", - max_tokens=100, - ) - with patch( "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined", return_value=True, + ), patch( + "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request", + return_value="processed_prompt", ): - with patch( - "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request", - return_value="processed_prompt", - ): - adapted_request, _ = serving_completion._convert_to_internal_request( - [request], ["test-id"] - ) + internal, _ = self.sc._convert_to_internal_request([req], ["id"]) + self.assertEqual(internal.text, "processed_prompt") - assert adapted_request.text == "processed_prompt" + # ---------- echo-handling ---------- + def test_echo_with_string_prompt_streaming(self): + req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True) + self.assertEqual(self.sc._get_echo_text(req, 0), "Hello") - -class TestEchoHandling: - """Test echo functionality from adapter.py""" - - def test_echo_with_string_prompt_streaming(self, serving_completion): - """Test echo handling with string prompt in streaming""" - request = CompletionRequest( - model="test-model", prompt="Hello", max_tokens=100, echo=True + def test_echo_with_list_of_strings_streaming(self): + req = CompletionRequest( + model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1 ) + self.assertEqual(self.sc._get_echo_text(req, 0), "A") + self.assertEqual(self.sc._get_echo_text(req, 1), "B") - # Test _get_echo_text method - echo_text = serving_completion._get_echo_text(request, 0) - assert echo_text == "Hello" + def test_echo_with_token_ids_streaming(self): + req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True) + self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt" + self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt") - def test_echo_with_list_of_strings_streaming(self, serving_completion): - """Test echo handling with list of strings in streaming""" - request = CompletionRequest( - model="test-model", - prompt=["Hello", "World"], - max_tokens=100, - echo=True, - n=1, + def test_echo_with_multiple_token_ids_streaming(self): + req = CompletionRequest( + model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1 ) + self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded" + self.assertEqual(self.sc._get_echo_text(req, 0), "decoded") - echo_text = serving_completion._get_echo_text(request, 0) - assert echo_text == "Hello" + def test_prepare_echo_prompts_non_streaming(self): + # single string + req = CompletionRequest(model="x", prompt="Hi", echo=True) + self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"]) - echo_text = serving_completion._get_echo_text(request, 1) - assert echo_text == "World" + # list of strings + req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True) + self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"]) - def test_echo_with_token_ids_streaming(self, serving_completion): - """Test echo handling with token IDs in streaming""" - request = CompletionRequest( - model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True - ) + # token IDs + req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True) + self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded" + 
+    def test_prepare_echo_prompts_non_streaming(self):
+        # single string
+        req = CompletionRequest(model="x", prompt="Hi", echo=True)
+        self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])
 
-        echo_text = serving_completion._get_echo_text(request, 1)
-        assert echo_text == "World"
+        # list of strings
+        req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
+        self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])
 
-    def test_echo_with_token_ids_streaming(self, serving_completion):
-        """Test echo handling with token IDs in streaming"""
-        request = CompletionRequest(
-            model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True
-        )
+        # token IDs
+        req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
+        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
+        self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
 
-        serving_completion.tokenizer_manager.tokenizer.decode.return_value = (
-            "decoded_prompt"
-        )
-        echo_text = serving_completion._get_echo_text(request, 0)
-        assert echo_text == "decoded_prompt"
 
-    def test_echo_with_multiple_token_ids_streaming(self, serving_completion):
-        """Test echo handling with multiple token ID prompts in streaming"""
-        request = CompletionRequest(
-            model="test-model", prompt=[[1, 2], [3, 4]], max_tokens=100, echo=True, n=1
-        )
-
-        serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
-        echo_text = serving_completion._get_echo_text(request, 0)
-        assert echo_text == "decoded"
-
-    def test_prepare_echo_prompts_non_streaming(self, serving_completion):
-        """Test prepare echo prompts for non-streaming response"""
-        # Test with single string
-        request = CompletionRequest(model="test-model", prompt="Hello", echo=True)
-
-        echo_prompts = serving_completion._prepare_echo_prompts(request)
-        assert echo_prompts == ["Hello"]
-
-        # Test with list of strings
-        request = CompletionRequest(
-            model="test-model", prompt=["Hello", "World"], echo=True
-        )
-
-        echo_prompts = serving_completion._prepare_echo_prompts(request)
-        assert echo_prompts == ["Hello", "World"]
-
-        # Test with token IDs
-        request = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True)
-
-        serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
-        echo_prompts = serving_completion._prepare_echo_prompts(request)
-        assert echo_prompts == ["decoded"]
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py
index 58438bba8..137db8fe5 100644
--- a/test/srt/openai/test_serving_embedding.py
+++ b/test/srt/openai/test_serving_embedding.py
@@ -8,11 +8,11 @@ with the original adapter.py functionality and follows OpenAI API specifications
 import asyncio
 import json
 import time
+import unittest
 import uuid
 from typing import Any, Dict, List
 from unittest.mock import AsyncMock, Mock, patch
 
-import pytest
 from fastapi import Request
 from fastapi.responses import ORJSONResponse
 from pydantic_core import ValidationError
@@ -30,7 +30,7 @@ from sglang.srt.managers.io_struct import EmbeddingReqInput
 
 
 # Mock TokenizerManager for embedding tests
-class MockTokenizerManager:
+class _MockTokenizerManager:
     def __init__(self):
         self.model_config = Mock()
         self.model_config.is_multimodal = False
@@ -58,141 +58,98 @@ class _MockTokenizerManager:
         self.generate_request = Mock(return_value=mock_generate_embedding())
 
 
-@pytest.fixture
-def mock_tokenizer_manager():
-    """Create a mock tokenizer manager for testing."""
-    return MockTokenizerManager()
+# IsolatedAsyncioTestCase (rather than plain TestCase) so the async
+# test_handle_request_* methods below are actually awaited; on a plain
+# TestCase they would return un-awaited coroutines and pass vacuously.
+class ServingEmbeddingTestCase(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        """Set up test fixtures."""
+        self.tokenizer_manager = _MockTokenizerManager()
+        self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
+        self.request = Mock(spec=Request)
+        self.request.headers = {}
 
-
-@pytest.fixture
-def serving_embedding(mock_tokenizer_manager):
-    """Create an OpenAIServingEmbedding instance for testing."""
-    return OpenAIServingEmbedding(mock_tokenizer_manager)
+        self.basic_req = EmbeddingRequest(
+            model="test-model",
+            input="Hello, how are you?",
+            encoding_format="float",
+        )
+        self.list_req = EmbeddingRequest(
+            model="test-model",
+            input=["Hello, how are you?", "I am fine, thank you!"],
+            encoding_format="float",
+        )
+        self.multimodal_req = EmbeddingRequest(
+            model="test-model",
+            input=[
+                MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
+                MultimodalEmbeddingInput(text="World", image=None),
+            ],
+            encoding_format="float",
+        )
+        self.token_ids_req = EmbeddingRequest(
+            model="test-model",
+            input=[1, 2, 3, 4, 5],
+            encoding_format="float",
+        )
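+        # These four request shapes cover the embedding surface exercised
+        # below: single string, list of strings, multimodal text+image
+        # pairs, and raw token IDs.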
model="test-model", + input=[ + MultimodalEmbeddingInput(text="Hello", image="base64_image_data"), + MultimodalEmbeddingInput(text="World", image=None), + ], + encoding_format="float", + ) + self.token_ids_req = EmbeddingRequest( + model="test-model", + input=[1, 2, 3, 4, 5], + encoding_format="float", + ) - -@pytest.fixture -def mock_request(): - """Create a mock FastAPI request.""" - request = Mock(spec=Request) - request.headers = {} - return request - - -@pytest.fixture -def basic_embedding_request(): - """Create a basic embedding request.""" - return EmbeddingRequest( - model="test-model", - input="Hello, how are you?", - encoding_format="float", - ) - - -@pytest.fixture -def list_embedding_request(): - """Create an embedding request with list input.""" - return EmbeddingRequest( - model="test-model", - input=["Hello, how are you?", "I am fine, thank you!"], - encoding_format="float", - ) - - -@pytest.fixture -def multimodal_embedding_request(): - """Create a multimodal embedding request.""" - return EmbeddingRequest( - model="test-model", - input=[ - MultimodalEmbeddingInput(text="Hello", image="base64_image_data"), - MultimodalEmbeddingInput(text="World", image=None), - ], - encoding_format="float", - ) - - -@pytest.fixture -def token_ids_embedding_request(): - """Create an embedding request with token IDs.""" - return EmbeddingRequest( - model="test-model", - input=[1, 2, 3, 4, 5], - encoding_format="float", - ) - - -class TestOpenAIServingEmbeddingConversion: - """Test request conversion methods.""" - - def test_convert_single_string_request( - self, serving_embedding, basic_embedding_request - ): + def test_convert_single_string_request(self): """Test converting single string request to internal format.""" adapted_request, processed_request = ( - serving_embedding._convert_to_internal_request( - [basic_embedding_request], ["test-id"] + self.serving_embedding._convert_to_internal_request( + [self.basic_req], ["test-id"] ) ) - assert isinstance(adapted_request, EmbeddingReqInput) - assert adapted_request.text == "Hello, how are you?" 
 
-    def test_convert_list_string_request(
-        self, serving_embedding, list_embedding_request
-    ):
+    def test_convert_list_string_request(self):
        """Test converting list of strings request to internal format."""
        adapted_request, processed_request = (
-            serving_embedding._convert_to_internal_request(
-                [list_embedding_request], ["test-id"]
+            self.serving_embedding._convert_to_internal_request(
+                [self.list_req], ["test-id"]
             )
         )
 
-        assert isinstance(adapted_request, EmbeddingReqInput)
-        assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"]
-        assert adapted_request.rid == "test-id"
-        assert processed_request == list_embedding_request
+        self.assertIsInstance(adapted_request, EmbeddingReqInput)
+        self.assertEqual(
+            adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
+        )
+        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(processed_request, self.list_req)
 
-    def test_convert_token_ids_request(
-        self, serving_embedding, token_ids_embedding_request
-    ):
+    def test_convert_token_ids_request(self):
         """Test converting token IDs request to internal format."""
         adapted_request, processed_request = (
-            serving_embedding._convert_to_internal_request(
-                [token_ids_embedding_request], ["test-id"]
+            self.serving_embedding._convert_to_internal_request(
+                [self.token_ids_req], ["test-id"]
             )
         )
 
-        assert isinstance(adapted_request, EmbeddingReqInput)
-        assert adapted_request.input_ids == [1, 2, 3, 4, 5]
-        assert adapted_request.rid == "test-id"
-        assert processed_request == token_ids_embedding_request
+        self.assertIsInstance(adapted_request, EmbeddingReqInput)
+        self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
+        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(processed_request, self.token_ids_req)
 
-    def test_convert_multimodal_request(
-        self, serving_embedding, multimodal_embedding_request
-    ):
+    def test_convert_multimodal_request(self):
         """Test converting multimodal request to internal format."""
         adapted_request, processed_request = (
-            serving_embedding._convert_to_internal_request(
-                [multimodal_embedding_request], ["test-id"]
+            self.serving_embedding._convert_to_internal_request(
+                [self.multimodal_req], ["test-id"]
             )
         )
 
-        assert isinstance(adapted_request, EmbeddingReqInput)
+        self.assertIsInstance(adapted_request, EmbeddingReqInput)
         # Should extract text and images separately
-        assert len(adapted_request.text) == 2
-        assert "Hello" in adapted_request.text
-        assert "World" in adapted_request.text
-        assert adapted_request.image_data[0] == "base64_image_data"
-        assert adapted_request.image_data[1] is None
-        assert adapted_request.rid == "test-id"
+        self.assertEqual(len(adapted_request.text), 2)
+        self.assertIn("Hello", adapted_request.text)
+        self.assertIn("World", adapted_request.text)
+        self.assertEqual(adapted_request.image_data[0], "base64_image_data")
+        self.assertIsNone(adapted_request.image_data[1])
+        self.assertEqual(adapted_request.rid, "test-id")
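+        # Multimodal inputs are split into parallel lists: the texts end up
+        # in adapted_request.text and the images (or None placeholders) in
+        # adapted_request.image_data, index-aligned.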
embedding.""" ret_data = [ { @@ -201,19 +158,21 @@ class TestEmbeddingResponseBuilding: } ] - response = serving_embedding._build_embedding_response(ret_data, "test-model") + response = self.serving_embedding._build_embedding_response( + ret_data, "test-model" + ) - assert isinstance(response, EmbeddingResponse) - assert response.model == "test-model" - assert len(response.data) == 1 - assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5] - assert response.data[0].index == 0 - assert response.data[0].object == "embedding" - assert response.usage.prompt_tokens == 5 - assert response.usage.total_tokens == 5 - assert response.usage.completion_tokens == 0 + self.assertIsInstance(response, EmbeddingResponse) + self.assertEqual(response.model, "test-model") + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5]) + self.assertEqual(response.data[0].index, 0) + self.assertEqual(response.data[0].object, "embedding") + self.assertEqual(response.usage.prompt_tokens, 5) + self.assertEqual(response.usage.total_tokens, 5) + self.assertEqual(response.usage.completion_tokens, 0) - def test_build_multiple_embedding_response(self, serving_embedding): + def test_build_multiple_embedding_response(self): """Test building response for multiple embeddings.""" ret_data = [ { @@ -226,25 +185,20 @@ class TestEmbeddingResponseBuilding: }, ] - response = serving_embedding._build_embedding_response(ret_data, "test-model") + response = self.serving_embedding._build_embedding_response( + ret_data, "test-model" + ) - assert isinstance(response, EmbeddingResponse) - assert len(response.data) == 2 - assert response.data[0].embedding == [0.1, 0.2, 0.3] - assert response.data[0].index == 0 - assert response.data[1].embedding == [0.4, 0.5, 0.6] - assert response.data[1].index == 1 - assert response.usage.prompt_tokens == 7 # 3 + 4 - assert response.usage.total_tokens == 7 + self.assertIsInstance(response, EmbeddingResponse) + self.assertEqual(len(response.data), 2) + self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3]) + self.assertEqual(response.data[0].index, 0) + self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6]) + self.assertEqual(response.data[1].index, 1) + self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4 + self.assertEqual(response.usage.total_tokens, 7) - -@pytest.mark.asyncio -class TestOpenAIServingEmbeddingAsyncMethods: - """Test async methods of OpenAIServingEmbedding.""" - - async def test_handle_request_success( - self, serving_embedding, basic_embedding_request, mock_request - ): + async def test_handle_request_success(self): """Test successful embedding request handling.""" # Mock the generate_request to return expected data @@ -254,32 +208,30 @@ class TestOpenAIServingEmbeddingAsyncMethods: "meta_info": {"prompt_tokens": 5}, } - serving_embedding.tokenizer_manager.generate_request = Mock( + self.serving_embedding.tokenizer_manager.generate_request = Mock( return_value=mock_generate() ) - response = await serving_embedding.handle_request( - basic_embedding_request, mock_request + response = await self.serving_embedding.handle_request( + self.basic_req, self.request ) - assert isinstance(response, EmbeddingResponse) - assert len(response.data) == 1 - assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5] + self.assertIsInstance(response, EmbeddingResponse) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5]) - async def 
 
-    async def test_handle_request_validation_error(
-        self, serving_embedding, mock_request
-    ):
+    async def test_handle_request_validation_error(self):
         """Test handling request with validation error."""
 
         invalid_request = EmbeddingRequest(model="test-model", input="")
 
-        response = await serving_embedding.handle_request(invalid_request, mock_request)
+        response = await self.serving_embedding.handle_request(
+            invalid_request, self.request
+        )
 
-        assert isinstance(response, ORJSONResponse)
-        assert response.status_code == 400
+        self.assertIsInstance(response, ORJSONResponse)
+        self.assertEqual(response.status_code, 400)
 
-    async def test_handle_request_generation_error(
-        self, serving_embedding, basic_embedding_request, mock_request
-    ):
+    async def test_handle_request_generation_error(self):
         """Test handling request with generation error."""
 
         # Mock generate_request to raise an error
@@ -287,30 +239,32 @@
             raise ValueError("Generation failed")
             yield  # This won't be reached but needed for async generator
 
-        serving_embedding.tokenizer_manager.generate_request = Mock(
+        self.serving_embedding.tokenizer_manager.generate_request = Mock(
             return_value=mock_generate_error()
         )
 
-        response = await serving_embedding.handle_request(
-            basic_embedding_request, mock_request
+        response = await self.serving_embedding.handle_request(
+            self.basic_req, self.request
         )
 
-        assert isinstance(response, ORJSONResponse)
-        assert response.status_code == 400
+        self.assertIsInstance(response, ORJSONResponse)
+        self.assertEqual(response.status_code, 400)
 
-    async def test_handle_request_internal_error(
-        self, serving_embedding, basic_embedding_request, mock_request
-    ):
+    async def test_handle_request_internal_error(self):
         """Test handling request with internal server error."""
 
         # Mock _convert_to_internal_request to raise an exception
         with patch.object(
-            serving_embedding,
+            self.serving_embedding,
             "_convert_to_internal_request",
             side_effect=Exception("Internal error"),
         ):
-            response = await serving_embedding.handle_request(
-                basic_embedding_request, mock_request
+            response = await self.serving_embedding.handle_request(
+                self.basic_req, self.request
             )
 
-        assert isinstance(response, ORJSONResponse)
-        assert response.status_code == 500
+        self.assertIsInstance(response, ORJSONResponse)
+        self.assertEqual(response.status_code, 500)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 06b09f9fa..58142e73b 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -62,6 +62,11 @@ suites = {
         TestFile("test_openai_adapter.py", 1),
         TestFile("test_openai_function_calling.py", 60),
         TestFile("test_openai_server.py", 149),
+        TestFile("openai/test_server.py", 120),
+        TestFile("openai/test_protocol.py", 60),
+        TestFile("openai/test_serving_chat.py", 120),
+        TestFile("openai/test_serving_completions.py", 120),
+        TestFile("openai/test_serving_embedding.py", 120),
         TestFile("test_openai_server_hidden_states.py", 240),
         TestFile("test_penalty.py", 41),
         TestFile("test_page_size.py", 60),