feat(oai refactor): Replace openai_api with entrypoints/openai (#7351)
Co-authored-by: Jin Pan <jpan236@wisc.edu>
This commit is contained in:
@@ -57,11 +57,21 @@ class _MockTokenizerManager:
|
||||
self.create_abort_task = Mock()
|
||||
|
||||
|
||||
class _MockTemplateManager:
|
||||
"""Minimal mock for TemplateManager."""
|
||||
|
||||
def __init__(self):
|
||||
self.chat_template_name: Optional[str] = "llama-3"
|
||||
self.jinja_template_content_format: Optional[str] = None
|
||||
self.completion_template_name: Optional[str] = None
|
||||
|
||||
|
||||
class ServingChatTestCase(unittest.TestCase):
|
||||
# ------------- common fixtures -------------
|
||||
def setUp(self):
|
||||
self.tm = _MockTokenizerManager()
|
||||
self.chat = OpenAIServingChat(self.tm)
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.chat = OpenAIServingChat(self.tm, self.template_manager)
|
||||
|
||||
# frequently reused requests
|
||||
self.basic_req = ChatCompletionRequest(
|
||||
@@ -109,96 +119,6 @@ class ServingChatTestCase(unittest.TestCase):
|
||||
self.assertFalse(adapted.stream)
|
||||
self.assertEqual(processed, self.basic_req)
|
||||
|
||||
# # ------------- tool-call branch -------------
|
||||
# def test_tool_call_request_conversion(self):
|
||||
# req = ChatCompletionRequest(
|
||||
# model="x",
|
||||
# messages=[{"role": "user", "content": "Weather?"}],
|
||||
# tools=[
|
||||
# {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "get_weather",
|
||||
# "parameters": {"type": "object", "properties": {}},
|
||||
# },
|
||||
# }
|
||||
# ],
|
||||
# tool_choice="auto",
|
||||
# )
|
||||
|
||||
# with patch.object(
|
||||
# self.chat,
|
||||
# "_process_messages",
|
||||
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||
# ):
|
||||
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
||||
# self.assertEqual(adapted.rid, "rid")
|
||||
|
||||
# def test_tool_choice_none(self):
|
||||
# req = ChatCompletionRequest(
|
||||
# model="x",
|
||||
# messages=[{"role": "user", "content": "Hi"}],
|
||||
# tools=[{"type": "function", "function": {"name": "noop"}}],
|
||||
# tool_choice="none",
|
||||
# )
|
||||
# with patch.object(
|
||||
# self.chat,
|
||||
# "_process_messages",
|
||||
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||
# ):
|
||||
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
||||
# self.assertEqual(adapted.rid, "rid")
|
||||
|
||||
# ------------- multimodal branch -------------
|
||||
def test_multimodal_request_with_images(self):
|
||||
self.tm.model_config.is_multimodal = True
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in the image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,"},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
self.chat,
|
||||
"_apply_jinja_template",
|
||||
return_value=("prompt", [1, 2], ["img"], None, [], []),
|
||||
), patch.object(
|
||||
self.chat,
|
||||
"_apply_conversation_template",
|
||||
return_value=("prompt", ["img"], None, [], []),
|
||||
):
|
||||
out = self.chat._process_messages(req, True)
|
||||
_, _, image_data, *_ = out
|
||||
self.assertEqual(image_data, ["img"])
|
||||
|
||||
# ------------- template handling -------------
|
||||
def test_jinja_template_processing(self):
|
||||
req = ChatCompletionRequest(
|
||||
model="x", messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
self.tm.chat_template_name = None
|
||||
self.tm.tokenizer.chat_template = "<jinja>"
|
||||
|
||||
with patch.object(
|
||||
self.chat,
|
||||
"_apply_jinja_template",
|
||||
return_value=("processed", [1], None, None, [], ["</s>"]),
|
||||
), patch("builtins.hasattr", return_value=True):
|
||||
prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
|
||||
self.assertEqual(prompt, "processed")
|
||||
self.assertEqual(prompt_ids, [1])
|
||||
|
||||
# ------------- sampling-params -------------
|
||||
def test_sampling_param_build(self):
|
||||
req = ChatCompletionRequest(
|
||||
|
||||
Reference in New Issue
Block a user