Refine OpenAI serving entrypoint to remove batch requests (#7372)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com> Co-authored-by: Chang Su <csu272@usc.edu>
This commit is contained in:
@@ -104,52 +104,50 @@ class ServingChatTestCase(unittest.TestCase):
|
||||
None,
|
||||
)
|
||||
|
||||
adapted, processed = self.chat._convert_to_internal_request(
|
||||
[self.basic_req], ["rid"]
|
||||
)
|
||||
adapted, processed = self.chat._convert_to_internal_request(self.basic_req)
|
||||
self.assertIsInstance(adapted, GenerateReqInput)
|
||||
self.assertFalse(adapted.stream)
|
||||
self.assertEqual(processed, self.basic_req)
|
||||
|
||||
# ------------- tool-call branch -------------
|
||||
def test_tool_call_request_conversion(self):
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Weather?"}],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
],
|
||||
tool_choice="auto",
|
||||
)
|
||||
# # ------------- tool-call branch -------------
|
||||
# def test_tool_call_request_conversion(self):
|
||||
# req = ChatCompletionRequest(
|
||||
# model="x",
|
||||
# messages=[{"role": "user", "content": "Weather?"}],
|
||||
# tools=[
|
||||
# {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "get_weather",
|
||||
# "parameters": {"type": "object", "properties": {}},
|
||||
# },
|
||||
# }
|
||||
# ],
|
||||
# tool_choice="auto",
|
||||
# )
|
||||
|
||||
with patch.object(
|
||||
self.chat,
|
||||
"_process_messages",
|
||||
return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||
):
|
||||
adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
|
||||
self.assertEqual(adapted.rid, "rid")
|
||||
# with patch.object(
|
||||
# self.chat,
|
||||
# "_process_messages",
|
||||
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||
# ):
|
||||
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
||||
# self.assertEqual(adapted.rid, "rid")
|
||||
|
||||
def test_tool_choice_none(self):
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=[{"type": "function", "function": {"name": "noop"}}],
|
||||
tool_choice="none",
|
||||
)
|
||||
with patch.object(
|
||||
self.chat,
|
||||
"_process_messages",
|
||||
return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||
):
|
||||
adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
|
||||
self.assertEqual(adapted.rid, "rid")
|
||||
# def test_tool_choice_none(self):
|
||||
# req = ChatCompletionRequest(
|
||||
# model="x",
|
||||
# messages=[{"role": "user", "content": "Hi"}],
|
||||
# tools=[{"type": "function", "function": {"name": "noop"}}],
|
||||
# tool_choice="none",
|
||||
# )
|
||||
# with patch.object(
|
||||
# self.chat,
|
||||
# "_process_messages",
|
||||
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||
# ):
|
||||
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
||||
# self.assertEqual(adapted.rid, "rid")
|
||||
|
||||
# ------------- multimodal branch -------------
|
||||
def test_multimodal_request_with_images(self):
|
||||
|
||||
@@ -36,12 +36,12 @@ class ServingCompletionTestCase(unittest.TestCase):
|
||||
# ---------- prompt-handling ----------
|
||||
def test_single_string_prompt(self):
|
||||
req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
|
||||
internal, _ = self.sc._convert_to_internal_request([req], ["id"])
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.text, "Hello world")
|
||||
|
||||
def test_single_token_ids_prompt(self):
|
||||
req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
|
||||
internal, _ = self.sc._convert_to_internal_request([req], ["id"])
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
|
||||
|
||||
def test_completion_template_handling(self):
|
||||
@@ -55,7 +55,7 @@ class ServingCompletionTestCase(unittest.TestCase):
|
||||
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
|
||||
return_value="processed_prompt",
|
||||
):
|
||||
internal, _ = self.sc._convert_to_internal_request([req], ["id"])
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.text, "processed_prompt")
|
||||
|
||||
# ---------- echo-handling ----------
|
||||
|
||||
@@ -94,50 +94,42 @@ class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
def test_convert_single_string_request(self):
|
||||
"""Test converting single string request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(
|
||||
self.basic_req, "test-id"
|
||||
)
|
||||
self.serving_embedding._convert_to_internal_request(self.basic_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(adapted_request.text, "Hello, how are you?")
|
||||
self.assertEqual(adapted_request.rid, None)
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.basic_req)
|
||||
|
||||
def test_convert_list_string_request(self):
|
||||
"""Test converting list of strings request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(
|
||||
self.list_req, "test-id"
|
||||
)
|
||||
self.serving_embedding._convert_to_internal_request(self.list_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(
|
||||
adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
|
||||
)
|
||||
self.assertEqual(adapted_request.rid, None)
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.list_req)
|
||||
|
||||
def test_convert_token_ids_request(self):
|
||||
"""Test converting token IDs request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(
|
||||
self.token_ids_req, "test-id"
|
||||
)
|
||||
self.serving_embedding._convert_to_internal_request(self.token_ids_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
|
||||
self.assertEqual(adapted_request.rid, None)
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.token_ids_req)
|
||||
|
||||
def test_convert_multimodal_request(self):
|
||||
"""Test converting multimodal request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(
|
||||
self.multimodal_req, "test-id"
|
||||
)
|
||||
self.serving_embedding._convert_to_internal_request(self.multimodal_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
@@ -147,7 +139,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
self.assertIn("World", adapted_request.text)
|
||||
self.assertEqual(adapted_request.image_data[0], "base64_image_data")
|
||||
self.assertIsNone(adapted_request.image_data[1])
|
||||
self.assertEqual(adapted_request.rid, None)
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
|
||||
def test_build_single_embedding_response(self):
|
||||
"""Test building response for single embedding."""
|
||||
@@ -194,72 +186,86 @@ class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
|
||||
self.assertEqual(response.usage.total_tokens, 7)
|
||||
|
||||
async def test_handle_request_success(self):
|
||||
def test_handle_request_success(self):
|
||||
"""Test successful embedding request handling."""
|
||||
|
||||
# Mock the generate_request to return expected data
|
||||
async def mock_generate():
|
||||
yield {
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"meta_info": {"prompt_tokens": 5},
|
||||
}
|
||||
async def run_test():
|
||||
# Mock the generate_request to return expected data
|
||||
async def mock_generate():
|
||||
yield {
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"meta_info": {"prompt_tokens": 5},
|
||||
}
|
||||
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate()
|
||||
)
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate()
|
||||
)
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
self.assertIsInstance(response, EmbeddingResponse)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
|
||||
async def test_handle_request_validation_error(self):
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_validation_error(self):
|
||||
"""Test handling request with validation error."""
|
||||
invalid_request = EmbeddingRequest(model="test-model", input="")
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
invalid_request, self.request
|
||||
)
|
||||
async def run_test():
|
||||
invalid_request = EmbeddingRequest(model="test-model", input="")
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
response = await self.serving_embedding.handle_request(
|
||||
invalid_request, self.request
|
||||
)
|
||||
|
||||
async def test_handle_request_generation_error(self):
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_generation_error(self):
|
||||
"""Test handling request with generation error."""
|
||||
|
||||
# Mock generate_request to raise an error
|
||||
async def mock_generate_error():
|
||||
raise ValueError("Generation failed")
|
||||
yield # This won't be reached but needed for async generator
|
||||
async def run_test():
|
||||
# Mock generate_request to raise an error
|
||||
async def mock_generate_error():
|
||||
raise ValueError("Generation failed")
|
||||
yield # This won't be reached but needed for async generator
|
||||
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate_error()
|
||||
)
|
||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||
return_value=mock_generate_error()
|
||||
)
|
||||
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
async def test_handle_request_internal_error(self):
|
||||
"""Test handling request with internal server error."""
|
||||
# Mock _convert_to_internal_request to raise an exception
|
||||
with patch.object(
|
||||
self.serving_embedding,
|
||||
"_convert_to_internal_request",
|
||||
side_effect=Exception("Internal error"),
|
||||
):
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 500)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_handle_request_internal_error(self):
|
||||
"""Test handling request with internal server error."""
|
||||
|
||||
async def run_test():
|
||||
# Mock _convert_to_internal_request to raise an exception
|
||||
with patch.object(
|
||||
self.serving_embedding,
|
||||
"_convert_to_internal_request",
|
||||
side_effect=Exception("Internal error"),
|
||||
):
|
||||
response = await self.serving_embedding.handle_request(
|
||||
self.basic_req, self.request
|
||||
)
|
||||
|
||||
self.assertIsInstance(response, ORJSONResponse)
|
||||
self.assertEqual(response.status_code, 500)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user