forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
@@ -0,0 +1,160 @@
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction, run_tool_extraction_streaming)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1

# Fixture pairs: each *_OUTPUT is the raw pythonic call text fed to the
# parser, and the matching *_CALL is the FunctionCall the tests expect back.

# Two plain string keyword arguments.
SIMPLE_FUNCTION_OUTPUT = (
    "get_weather(city='San Francisco', metric='celsius')")
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "San Francisco", "metric": "celsius"}',
)

# Mix of string, int, dict, None, bool and list argument values.
MORE_TYPES_FUNCTION_OUTPUT = (
    "register_user("
    "name='John Doe', age=37, "
    "address={'city': 'San Francisco', 'state': 'CA'}, "
    "role=None, passed_test=True, "
    "aliases=['John', 'Johnny'])"
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
    name="register_user",
    arguments=(
        '{"name": "John Doe", "age": 37, '
        '"address": {"city": "San Francisco", "state": "CA"}, '
        '"role": null, "passed_test": true, '
        '"aliases": ["John", "Johnny"]}'
    ),
)

# Call without any arguments.
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{}',
)

# Single argument whose value is an empty dict.
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"additional_data": {}}',
)

# Single argument whose value is an empty list.
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"steps": []}',
)

# String arguments containing escaped single and double quotes. The output is
# a raw string, so the backslashes are part of the text given to the parser.
ESCAPED_STRING_FUNCTION_OUTPUT = (
    r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')")
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool):
    """Plain text with no tool call is returned unchanged as content.

    Runs once through the streaming path and once through the non-streaming
    path, and expects no tool calls in either case.
    """
    parser_cls = ToolParserManager.get_tool_parser("pythonic")
    tool_parser: ToolParser = parser_cls(MagicMock())
    model_output = "How can I help you today?"

    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=streaming)

    assert content == model_output
    assert len(tool_calls) == 0
|
||||
|
||||
|
||||
# Parametrization table for test_tool_call. Every scenario appears twice, once
# with streaming=True and once with streaming=False. Each row is
# (test id, streaming flag, raw model output, expected function calls).
TEST_CASES = [
    pytest.param(streaming, model_output, expected_calls, id=case_id)
    for case_id, streaming, model_output, expected_calls in (
        ("simple_streaming", True,
         f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL]),
        ("simple_nonstreaming", False,
         f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL]),
        ("more_types_streaming", True,
         f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL]),
        ("more_types_nonstreaming", False,
         f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL]),
        ("parameterless_streaming", True,
         f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
         [PARAMETERLESS_FUNCTION_CALL]),
        ("parameterless_nonstreaming", False,
         f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
         [PARAMETERLESS_FUNCTION_CALL]),
        ("empty_dict_streaming", True,
         f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL]),
        ("empty_dict_nonstreaming", False,
         f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL]),
        ("empty_list_streaming", True,
         f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL]),
        ("empty_list_nonstreaming", False,
         f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL]),
        ("escaped_string_streaming", True,
         f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
         [ESCAPED_STRING_FUNCTION_CALL]),
        ("escaped_string_nonstreaming", False,
         f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
         [ESCAPED_STRING_FUNCTION_CALL]),
        ("parallel_calls_streaming", True,
         f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
         [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL]),
        ("parallel_calls_nonstreaming", False,
         f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
         [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL]),
    )
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
                         TEST_CASES)
def test_tool_call(streaming: bool, model_output: str,
                   expected_tool_calls: List[FunctionCall]):
    """Check that the pythonic parser extracts the expected function calls.

    Args:
        streaming: Whether to drive the parser through its streaming path.
        model_output: Raw model text containing one or more pythonic calls.
        expected_tool_calls: Function calls expected back, in order.
    """
    parser_cls = ToolParserManager.get_tool_parser("pythonic")
    tool_parser: ToolParser = parser_cls(MagicMock())

    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=streaming)

    # These outputs consist only of tool calls, so no content is expected.
    assert content is None
    assert len(tool_calls) == len(expected_tool_calls)
    for position, expected in enumerate(expected_tool_calls):
        actual = tool_calls[position]
        assert actual.type == "function"
        assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps():
    """Stream three tool calls in two large deltas and check all are rebuilt.

    The first delta stops in the middle of a string argument; the second
    delta finishes that call and carries two further complete calls, so the
    one-tool-call-per-delta assertion is switched off.
    """
    parser_cls = ToolParserManager.get_tool_parser("pythonic")
    tool_parser: ToolParser = parser_cls(MagicMock())

    opening_delta = "[get_weather(city='San"
    closing_delta = (" Francisco', metric='celsius'), "
                     f"{PARAMETERLESS_FUNCTION_OUTPUT}, "
                     f"{EMPTY_LIST_FUNCTION_OUTPUT}]")

    reconstructor = run_tool_extraction_streaming(
        tool_parser, [opening_delta, closing_delta],
        assert_one_tool_per_delta=False)

    assert reconstructor.other_content == ""
    assert len(reconstructor.tool_calls) == 3
    expected_calls = (
        SIMPLE_FUNCTION_CALL,
        PARAMETERLESS_FUNCTION_CALL,
        EMPTY_LIST_FUNCTION_CALL,
    )
    for rebuilt, expected in zip(reconstructor.tool_calls, expected_calls):
        assert rebuilt.function == expected
|
||||
123
vllm-v0.6.2/tests/entrypoints/openai/tool_parsers/utils.py
Normal file
123
vllm-v0.6.2/tests/entrypoints/openai/tool_parsers/utils.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import Iterable, List, Tuple, Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
|
||||
|
||||
class StreamingToolReconstructor:
    """Rebuild complete tool calls from a sequence of streamed deltas.

    Each delta passed to ``append_delta`` is validated against the streaming
    rules asserted below and folded into ``tool_calls`` (one entry per tool
    call index) or ``other_content`` (concatenated plain content).
    """

    def __init__(self, assert_one_tool_per_delta: bool = True):
        """Create an empty reconstructor.

        Args:
            assert_one_tool_per_delta: When true, ``append_delta`` asserts
                that a single delta carries fewer than two tool calls.
        """
        # Tool calls rebuilt so far; the list position is the tool index.
        self.tool_calls: List[ToolCall] = []
        # Concatenation of every non-None ``content`` fragment seen.
        self.other_content: str = ""
        self._assert_one_tool_per_delta = assert_one_tool_per_delta

    def append_delta(self, delta: DeltaMessage):
        """Validate one streamed delta and merge it into the running state.

        Args:
            delta: Streamed message fragment carrying ``content`` and/or
                ``tool_calls``.

        Raises:
            AssertionError: If the delta has neither content nor tool calls,
                carries more than one tool call while that check is enabled,
                repeats an id or function name for an already known tool
                call, omits them on first appearance, or uses an index that
                is not the current or the next tool call.
        """
        if delta.content is not None:
            self.other_content += delta.content
        else:
            assert delta.tool_calls, (
                "Streaming results should have either content or tool calls "
                "(or both)")
        if self._assert_one_tool_per_delta:
            # Note: This isn't strictly required by the API and may not be
            # possible to adhere to depending on the token space and number of
            # tokens per streamed response from the model, but it is required
            # by tool_use tests, so we enforce it here by default also.
            assert len(delta.tool_calls) < 2, (
                "Streaming should include only one tool call per update.")
        for call_delta in delta.tool_calls:
            assert call_delta.type == "function", (
                "Streaming tool calls should only emit function calls. Got "
                f"{call_delta.type}")
            # An index inside the list continues an existing tool call;
            # anything else is treated as the first fragment of a new one.
            current_tool_call = self.tool_calls[
                call_delta.index] if call_delta.index < len(
                    self.tool_calls) else None
            if current_tool_call:
                assert (not call_delta.function.name), (
                    "Streaming tool calls should emit the full function name "
                    f"exactly once. Got {call_delta.function.name}")
                assert (not call_delta.id), (
                    "Streaming tool calls must emit function id only once. Got "
                    f"{call_delta.id}")
                # Only the most recently opened tool call may be extended.
                assert (call_delta.index == len(self.tool_calls) - 1), (
                    f"Incorrect index for tool delta. Got {call_delta.index}, "
                    f"expected {len(self.tool_calls) - 1}")
                # A follow-up fragment may carry no argument text; treat it
                # as an empty string (same normalisation as the
                # first-appearance branch) instead of raising TypeError.
                current_tool_call.function.arguments += (
                    call_delta.function.arguments or "")
            else:
                assert call_delta.id is not None, (
                    "Streaming tool calls must have an id on first appearance")
                assert call_delta.function.name is not None, (
                    "Streaming tool calls must have a function name on first "
                    "appearance")
                # New tool calls must arrive with consecutive indices.
                assert call_delta.index == len(self.tool_calls), (
                    f"Incorrect index for tool delta. Got {call_delta.index}, "
                    f"expected {len(self.tool_calls)}")
                self.tool_calls.append(
                    ToolCall(id=call_delta.id,
                             function=FunctionCall(
                                 name=call_delta.function.name,
                                 arguments=call_delta.function.arguments
                                 or "")))
|
||||
|
||||
|
||||
def run_tool_extraction(
    tool_parser: ToolParser,
    model_output: str,
    request: Union[ChatCompletionRequest, None] = None,
    streaming: bool = False,
    assert_one_tool_per_delta: bool = True,
) -> Tuple[Union[str, None], List[ToolCall]]:
    """Run tool extraction in streaming or non-streaming mode.

    Args:
        tool_parser: Parser under test.
        model_output: Full model text. In streaming mode the string itself
            is handed over as the sequence of deltas.
        request: Optional request forwarded to the parser.
        streaming: Selects the streaming path when true.
        assert_one_tool_per_delta: Forwarded to the streaming path only.

    Returns:
        A ``(content, tool_calls)`` tuple. In streaming mode, empty
        reconstructed content is reported as ``None``.
    """
    if not streaming:
        extracted = run_tool_extraction_nonstreaming(tool_parser,
                                                     model_output, request)
        # The parser's tools_called flag must agree with what it returned.
        assert extracted.tools_called == bool(extracted.tool_calls)
        return extracted.content, extracted.tool_calls

    reconstructor = run_tool_extraction_streaming(
        tool_parser,
        model_output,
        request,
        assert_one_tool_per_delta=assert_one_tool_per_delta)
    streamed_content = reconstructor.other_content
    return (streamed_content if streamed_content else None,
            reconstructor.tool_calls)
|
||||
|
||||
|
||||
def run_tool_extraction_nonstreaming(
    tool_parser: ToolParser,
    model_output: str,
    request: Union[ChatCompletionRequest, None] = None
) -> ExtractedToolCallInformation:
    """Extract tool calls from a complete model output in one call.

    Args:
        tool_parser: Parser under test.
        model_output: Full model text to parse.
        request: Optional request; a minimal placeholder request is built
            when none is supplied.

    Returns:
        Whatever ``tool_parser.extract_tool_calls`` returns.
    """
    if not request:
        request = ChatCompletionRequest(messages=[], model="test-model")
    return tool_parser.extract_tool_calls(model_output, request)
|
||||
|
||||
|
||||
def run_tool_extraction_streaming(
    tool_parser: ToolParser,
    model_deltas: Iterable[str],
    request: Union[ChatCompletionRequest, None] = None,
    assert_one_tool_per_delta: bool = True,
) -> StreamingToolReconstructor:
    """Feed text deltas to the parser one by one and rebuild the result.

    Args:
        tool_parser: Parser under test.
        model_deltas: Text fragments in the order they are streamed.
        request: Optional request; a minimal placeholder request is built
            when none is supplied.
        assert_one_tool_per_delta: Passed to the reconstructor.

    Returns:
        The reconstructor holding the accumulated content and tool calls.
    """
    if not request:
        request = ChatCompletionRequest(messages=[], model="test-model")
    reconstructor = StreamingToolReconstructor(
        assert_one_tool_per_delta=assert_one_tool_per_delta)

    text_so_far = ""
    tokens_so_far: List[int] = []
    for delta in model_deltas:
        # Keep only ids of tokens that the parser's vocabulary knows.
        new_token_ids = []
        for token in tool_parser.model_tokenizer.tokenize(delta):
            if token in tool_parser.vocab:
                new_token_ids.append(tool_parser.vocab.get(token))

        text_with_delta = text_so_far + delta
        tokens_with_delta = tokens_so_far + new_token_ids
        delta_message = tool_parser.extract_tool_calls_streaming(
            text_so_far, text_with_delta, delta, tokens_so_far,
            tokens_with_delta, new_token_ids, request)
        # The parser may return None for a delta; nothing is recorded then.
        if delta_message is not None:
            reconstructor.append_delta(delta_message)

        text_so_far, tokens_so_far = text_with_delta, tokens_with_delta
    return reconstructor
|
||||
Reference in New Issue
Block a user