Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from vllm.tokenizers import TokenizerLike
@pytest.fixture(scope="function")
def default_tokenizer() -> TokenizerLike:
    """Provide a fresh GPT-2 tokenizer for each test function."""
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return tokenizer

View File

@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
# Shared fixtures for the gigachat3 tool-parser tests below.  Each *_OUTPUT
# constant is the raw model text ("function call" marker followed by a JSON
# payload) and the matching *_CALL constant is the expected parse result.
SIMPLE_ARGS_DICT = {
    "action": "create",
    "id": "preferences",
}
SIMPLE_FUNCTION_JSON = json.dumps(
    {
        "name": "manage_user_memory",
        "arguments": SIMPLE_ARGS_DICT,
    },
    ensure_ascii=False,
)
SIMPLE_FUNCTION_OUTPUT = "function call" + SIMPLE_FUNCTION_JSON
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="manage_user_memory",
    arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False),
)
# Variant with an empty arguments object.
PARAMETERLESS_FUNCTION_JSON = json.dumps(
    {
        "name": "manage_user_memory",
        "arguments": {},
    },
    ensure_ascii=False,
)
PARAMETERLESS_FUNCTION_OUTPUT = "function call" + PARAMETERLESS_FUNCTION_JSON
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
    name="manage_user_memory",
    arguments=json.dumps({}, ensure_ascii=False),
)
# Variant with a nested arguments object.
COMPLEX_ARGS_DICT = {
    "action": "create",
    "id": "preferences",
    "content": {
        "short_answers": True,
        "hate_emojis": True,
        "english_ui": False,
        "russian_math_explanations": True,
    },
}
COMPLEX_FUNCTION_JSON = json.dumps(
    {
        "name": "manage_user_memory",
        "arguments": COMPLEX_ARGS_DICT,
    },
    ensure_ascii=False,
)
COMPLEX_FUNCTION_OUTPUT = "function call" + COMPLEX_FUNCTION_JSON
COMPLEX_FUNCTION_CALL = FunctionCall(
    name="manage_user_memory",
    arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False),
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text must pass through as content with no tool calls extracted."""
    parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
        default_tokenizer
    )
    plain_text = "How can I help you today?"
    content, calls = run_tool_extraction(parser, plain_text, streaming=streaming)
    assert content == plain_text
    assert not calls
# (streaming, model_output, expected_tool_calls, expected_content) tuples,
# exercising each fixture above in both streaming and non-streaming mode.
# expected_content is None in every case: the output is a pure tool call.
TEST_CASES = [
    pytest.param(
        True,
        SIMPLE_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        None,
        id="simple_streaming",
    ),
    pytest.param(
        False,
        SIMPLE_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        None,
        id="simple_nonstreaming",
    ),
    pytest.param(
        True,
        PARAMETERLESS_FUNCTION_OUTPUT,
        [PARAMETERLESS_FUNCTION_CALL],
        None,
        id="parameterless_streaming",
    ),
    pytest.param(
        False,
        PARAMETERLESS_FUNCTION_OUTPUT,
        [PARAMETERLESS_FUNCTION_CALL],
        None,
        id="parameterless_nonstreaming",
    ),
    pytest.param(
        True,
        COMPLEX_FUNCTION_OUTPUT,
        [COMPLEX_FUNCTION_CALL],
        None,
        id="complex_streaming",
    ),
    pytest.param(
        False,
        COMPLEX_FUNCTION_OUTPUT,
        [COMPLEX_FUNCTION_CALL],
        None,
        id="complex_nonstreaming",
    ),
]
@pytest.mark.parametrize(
    "streaming, model_output, expected_tool_calls, expected_content", TEST_CASES
)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    expected_content: str | None,
    default_tokenizer: TokenizerLike,
):
    """Extracted tool calls must match the expected names and arguments."""
    parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
        default_tokenizer
    )
    content, calls = run_tool_extraction(parser, model_output, streaming=streaming)
    assert content == expected_content
    assert len(calls) == len(expected_tool_calls)
    for got, want in zip(calls, expected_tool_calls):
        assert got.type == "function"
        assert got.function.name == want.name
        # Compare decoded dicts so key order / whitespace differences are ignored.
        assert json.loads(got.function.arguments) == json.loads(want.arguments)
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """A tool call split across a few large deltas must still be reassembled."""
    parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
        default_tokenizer
    )
    deltas = [
        "function call",
        COMPLEX_FUNCTION_JSON[:40],
        COMPLEX_FUNCTION_JSON[40:],
    ]
    reconstructor = run_tool_extraction_streaming(
        parser,
        deltas,
        assert_one_tool_per_delta=False,
    )
    assert len(reconstructor.tool_calls) == 1
    reassembled = reconstructor.tool_calls[0]
    assert reassembled.type == "function"
    assert reassembled.function.name == "manage_user_memory"
    assert json.loads(reassembled.function.arguments) == COMPLEX_ARGS_DICT

View File

@@ -0,0 +1,460 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from ....utils import RemoteOpenAIServer
# Base model plus a tool-calling LoRA adapter; the server runs the hermes
# parser with the LoRA module registered under its own name and the adapter's
# tokenizer used in place of the base model's.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
LORA_MODEL = "minpeter/LoRA-Llama-3.2-1B-tool-vllm-ci"
SERVER_ARGS = [
    "--enforce-eager",
    "--enable-auto-tool-choice",
    "--tool-call-parser",
    "hermes",
    "--enable-lora",
    "--lora-modules",
    f"{LORA_MODEL}={LORA_MODEL}",
    "--tokenizer",
    f"{LORA_MODEL}",
]
# Weather tool: one required string parameter plus an enum.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["location"],
            },
        },
    }
]
# Product tool: integer and boolean parameters, used to check type fidelity
# of parsed arguments.
PRODUCT_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_product_info",
            "description": "Get detailed information of a product based on its "
            "product ID.",
            "parameters": {
                "type": "object",
                "properties": {
                    "inserted": {
                        "type": "boolean",
                        "description": "inserted.",
                    },
                    "product_id": {
                        "type": "integer",
                        "description": "The product ID of the product.",
                    },
                },
                "required": ["product_id", "inserted"],
            },
        },
    }
]
MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
PRODUCT_MESSAGES = [
    {
        "role": "user",
        "content": "Hi! Do you have any detailed information about the product id "
        "7355608 and inserted true?",
    }
]
@pytest.mark.asyncio
async def test_non_streaming_tool_call():
    """Test tool call in non-streaming mode."""
    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
        client = server.get_async_client()
        response = await client.chat.completions.create(
            model=LORA_MODEL,
            messages=MESSAGES,
            tools=TOOLS,
            tool_choice="auto",
            temperature=0.0,
        )
        assert response.choices
        first_choice = response.choices[0]
        assert first_choice.finish_reason == "tool_calls"
        assert first_choice.message.tool_calls is not None
        call = first_choice.message.tool_calls[0]
        assert call.type == "function"
        assert call.function.name == "get_current_weather"
        arguments = json.loads(call.function.arguments)
        assert "location" in arguments
        assert "Boston" in arguments["location"]
        print("\n[Non-Streaming Test Passed]")
        print(f"Tool Call: {call.function.name}")
        print(f"Arguments: {arguments}")
@pytest.mark.asyncio
async def test_streaming_tool_call():
    """Test tool call in streaming mode."""
    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
        client = server.get_async_client()
        stream = await client.chat.completions.create(
            model=LORA_MODEL,
            messages=MESSAGES,
            tools=TOOLS,
            tool_choice="auto",
            temperature=0.0,
            stream=True,
        )
        # Accumulate name/argument fragments per tool-call index.
        accumulated: dict[int, dict[str, str]] = {}
        async for chunk in stream:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if not delta or not delta.tool_calls:
                continue
            for piece in delta.tool_calls:
                entry = accumulated.setdefault(
                    piece.index, {"name": "", "arguments": ""}
                )
                if piece.function.name:
                    entry["name"] += piece.function.name
                if piece.function.arguments:
                    entry["arguments"] += piece.function.arguments
        assert len(accumulated) == 1
        reconstructed_tool_call = accumulated[0]
        assert reconstructed_tool_call["name"] == "get_current_weather"
        arguments = json.loads(reconstructed_tool_call["arguments"])
        assert "location" in arguments
        assert "Boston" in arguments["location"]
        print("\n[Streaming Test Passed]")
        print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
        print(f"Reconstructed Arguments: {arguments}")
@pytest.mark.asyncio
async def test_non_streaming_product_tool_call():
    """Test tool call integer and boolean parameters in non-streaming mode."""
    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
        client = server.get_async_client()
        response = await client.chat.completions.create(
            model=LORA_MODEL,
            messages=PRODUCT_MESSAGES,
            tools=PRODUCT_TOOLS,
            tool_choice="auto",
            temperature=0.66,
        )
        assert response.choices
        first_choice = response.choices[0]
        assert first_choice.finish_reason == "tool_calls"
        assert first_choice.message.tool_calls is not None
        call = first_choice.message.tool_calls[0]
        assert call.type == "function"
        assert call.function.name == "get_product_info"
        arguments = json.loads(call.function.arguments)
        assert "product_id" in arguments
        assert "inserted" in arguments
        # The parser must preserve JSON types, not coerce them to strings.
        product_id = arguments.get("product_id")
        inserted = arguments.get("inserted")
        assert isinstance(product_id, int)
        assert product_id == 7355608
        assert isinstance(inserted, bool)
        assert inserted is True
        print("\n[Non-Streaming Product Test Passed]")
        print(f"Tool Call: {call.function.name}")
        print(f"Arguments: {arguments}")
@pytest.mark.asyncio
async def test_streaming_product_tool_call():
    """Test tool call integer and boolean parameters in streaming mode."""
    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
        client = server.get_async_client()
        stream = await client.chat.completions.create(
            model=LORA_MODEL,
            messages=PRODUCT_MESSAGES,
            tools=PRODUCT_TOOLS,
            tool_choice="auto",
            temperature=0.66,
            stream=True,
        )
        # Accumulate name/argument fragments per tool-call index.
        accumulated: dict[int, dict[str, str]] = {}
        async for chunk in stream:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if not delta or not delta.tool_calls:
                continue
            for piece in delta.tool_calls:
                entry = accumulated.setdefault(
                    piece.index, {"name": "", "arguments": ""}
                )
                if piece.function.name:
                    entry["name"] += piece.function.name
                if piece.function.arguments:
                    entry["arguments"] += piece.function.arguments
        assert len(accumulated) == 1
        reconstructed_tool_call = accumulated[0]
        assert reconstructed_tool_call["name"] == "get_product_info"
        arguments = json.loads(reconstructed_tool_call["arguments"])
        assert "product_id" in arguments
        assert "inserted" in arguments
        # Handle type coercion for streaming test as well
        product_id = arguments.get("product_id")
        inserted = arguments.get("inserted")
        assert isinstance(product_id, int)
        assert product_id == 7355608
        assert isinstance(inserted, bool)
        assert inserted is True
        print("\n[Streaming Product Test Passed]")
        print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
        print(f"Reconstructed Arguments: {arguments}")
@pytest.fixture
def qwen_tokenizer() -> TokenizerLike:
    """Load the Qwen3-32B tokenizer used by the hermes parser tests."""
    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer("Qwen/Qwen3-32B")
    return tokenizer
@pytest.fixture
def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
    """Hermes parser instance bound to the real Qwen tokenizer."""
    parser = Hermes2ProToolParser(qwen_tokenizer)
    return parser
@pytest.fixture
def any_chat_request() -> ChatCompletionRequest:
    """Minimal request object; the parser only needs something well-formed."""
    request = ChatCompletionRequest(
        seed=42,
        model="Qwen/Qwen3-32B",
        messages=[],
    )
    return request
def test_hermes_parser_streaming_just_forward_text(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Non-tool text must be forwarded verbatim as content deltas."""
    text = """This is some prior text that has nothing to do with tool calling."""
    seen = ""
    deltas = []
    for token_id in qwen_tokenizer.encode(text):
        piece = qwen_tokenizer.decode([token_id])
        grown = seen + piece
        message = hermes_parser.extract_tool_calls_streaming(
            previous_text=seen,
            current_text=grown,
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        seen = grown
        deltas.append(message)
    for message in deltas:
        assert message is not None
        assert not message.tool_calls
    print(deltas)
    assert "".join([message.content for message in deltas]) == text
def test_hermes_parser_streaming_failure_case_bug_19056(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Regression test: a complete tagged tool call streams token by token."""
    text = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}
</tool_call>"""
    token_ids = qwen_tokenizer.encode(text)
    seen = ""
    collected = []
    for token_id in token_ids:
        piece = qwen_tokenizer.decode([token_id])
        grown = seen + piece
        message = hermes_parser.extract_tool_calls_streaming(
            previous_text=seen,
            current_text=grown,
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        seen = grown
        if message is not None:
            collected.append(message)
    assert collected[0].tool_calls[0].function.name == "final_answer"
    joined_args = "".join(
        message.tool_calls[0].function.arguments or "" for message in collected
    )
    assert joined_args == '{"trigger": true}'
def test_hermes_parser_streaming(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A single-line tagged tool call is reassembled from streamed tokens."""
    text = '<tool_call>\
{"name": "get_current_temperature",\
"arguments": {"location":\
"San Francisco, California, United States", "unit": "celsius"}}\
</tool_call>'
    token_ids = qwen_tokenizer.encode(text)
    seen = ""
    collected = []
    for token_id in token_ids:
        piece = qwen_tokenizer.decode([token_id])
        grown = seen + piece
        message = hermes_parser.extract_tool_calls_streaming(
            previous_text=seen,
            current_text=grown,
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        seen = grown
        if message is not None:
            collected.append(message)
    print(collected)
    assert collected[0].tool_calls[0].function.name == "get_current_temperature"
    joined_args = "".join(
        message.tool_calls[0].function.arguments or "" for message in collected
    )
    assert joined_args == (
        '{"location":"San Francisco, California, United States", "unit": "celsius"}'
    )
def test_hermes_parser_non_streaming_no_tool_call(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Plain text yields a result with tools_called False."""
    result = hermes_parser.extract_tool_calls(
        model_output="""This is not a tool call.""",
        request=any_chat_request,
    )
    assert result is not None
    assert not result.tools_called
def test_hermes_parser_non_streaming_tool_call_between_tags(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A fully-tagged tool call is extracted with its JSON arguments intact."""
    result = hermes_parser.extract_tool_calls(
        model_output="""<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}
</tool_call>""",
        request=any_chat_request,
    )
    assert result is not None
    assert result.tools_called
    first = result.tool_calls[0]
    assert first.function.name == "final_answer"
    assert first.function.arguments == '{"trigger": true}'
def test_hermes_parser_non_streaming_tool_call_until_eos(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A tool call missing its closing tag (cut off at EOS) still parses."""
    result = hermes_parser.extract_tool_calls(
        model_output="""<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}""",
        request=any_chat_request,
    )
    assert result is not None
    assert result.tools_called
    first = result.tool_calls[0]
    assert first.function.name == "final_answer"
    assert first.function.arguments == '{"trigger": true}'
def test_hermes_parser_non_streaming_tool_call_invalid_json(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Malformed JSON inside the tags must not raise; no tools are reported."""
    # Missing closing brace to trigger exception
    result = hermes_parser.extract_tool_calls(
        model_output="""<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}""",
        request=any_chat_request,
    )
    assert result is not None
    assert not result.tools_called

View File

@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
from unittest.mock import MagicMock
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.tool_parsers import ToolParser, ToolParserManager
def make_tool_call(name, arguments):
    """Build a function-type ToolCall with JSON-serialized arguments."""
    serialized = json.dumps(arguments)
    return ToolCall(
        type="function",
        function=FunctionCall(name=name, arguments=serialized),
    )
# TODO: add reason prefix and suffix.
@pytest.mark.parametrize(
    "model_output,expected_tool_calls,expected_content",
    [
        # No tool call
        ("How can I help you today?", [], "How can I help you today?"),
        # Single tool call, no content
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]</tool_calls>',  # noqa: E501
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                )
            ],
            None,
        ),
        # Multiple tool calls
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]</tool_calls>',  # noqa: E501
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                ),
                make_tool_call(
                    "register_user",
                    {
                        "name": "John Doe",
                        "age": 37,
                        "address": {"city": "San Francisco", "state": "CA"},
                        "role": None,
                        "passed_test": True,
                        "aliases": ["John", "Johnny"],
                    },
                ),
            ],
            None,
        ),
        # Content before tool call
        (
            'I will call the tool now. <tool_calls>[{"name": "get_weather", "arguments": {"city": "Boston"}}]</tool_calls>',  # noqa: E501
            [make_tool_call("get_weather", {"city": "Boston"})],
            "I will call the tool now. ",
        ),
        # Content after tool call (should be stripped)
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "Seattle"}}]</tool_calls>\nThank you!',  # noqa: E501
            [make_tool_call("get_weather", {"city": "Seattle"})],
            None,
        ),
        # Deeply nested argument object
        (
            '<tool_calls>[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]</tool_calls>',
            [
                make_tool_call(
                    "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
                )
            ],
            None,
        ),
    ],
)
def test_hunyuan_a13b_tool_parser_extract(
    model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction for the hunyuan_a13b parser.

    The tokenizer is a MagicMock; presumably non-streaming extraction only
    operates on the raw text — TODO confirm against the parser implementation.
    """
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
        mock_tokenizer
    )
    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=False
    )
    # align the random id.
    for idx in range(len(tool_calls)):
        tool_calls[idx].id = expected_tool_calls[idx].id
    assert tool_calls == expected_tool_calls
    assert content == expected_content
# Streaming test: simulate incremental output
@pytest.mark.parametrize(
    "model_deltas,expected_tool_calls",
    [
        (
            [
                '<tool_calls>[{"name": "get_weather", ',
                '"arguments": {"city": "San Francisco", ',
                '"metric": "celsius"}}]',
                "</tool_calls>",
            ],
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                )
            ],
        ),
        (
            [
                '<tool_calls>[{"name":',
                ' "get_weather",',
                ' "arguments":',
                ' {"city": "Boston"}',
                "}]",
                "</tool_calls>",
            ],
            [make_tool_call("get_weather", {"city": "Boston"})],
        ),
        # Leading empty delta plus trailing text after the closing tag.
        (
            [
                "",
                '<tool_calls>[{"name":',
                ' "get_weather",',
                ' "arguments":',
                ' {"city": "Boston"}',
                "}]",
                "</tool_calls>",
                "\n</answer>",
            ],
            [make_tool_call("get_weather", {"city": "Boston"})],
        ),
        # Nested-object arguments: known gap in the streaming parser (xfail).
        pytest.param(
            [
                '<tool_calls>[{"name": "complex_tool",',
                ' "arguments": ',
                ' {"level1": {"level2": ',
                '{"level3": {"value": 123}}}}}',
                "]</tool_calls>",
            ],
            [
                make_tool_call(
                    "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
                )
            ],
            marks=pytest.mark.xfail(
                reason="stream parsing not support nested json yet."
            ),
        ),
    ],
)
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
    """Streaming extraction: feed deltas one at a time and compare the
    reconstructed tool calls against the expected ones (ids are randomized,
    so they are aligned before comparison)."""
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
        mock_tokenizer
    )
    reconstructor = run_tool_extraction_streaming(
        tool_parser, model_deltas, assert_one_tool_per_delta=False
    )
    # align the random id.
    for idx in range(len(reconstructor.tool_calls)):
        reconstructor.tool_calls[idx].id = expected_tool_calls[idx].id
    assert reconstructor.tool_calls == expected_tool_calls

View File

@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.llama_tool_parser import Llama3JsonToolParser
@pytest.fixture
def parser(default_tokenizer: TokenizerLike):
    """Llama3 JSON tool parser built on the shared default tokenizer."""
    instance = Llama3JsonToolParser(default_tokenizer)
    return instance
def test_extract_tool_calls_simple(parser):
    """A no-argument tool call embedded in prose is extracted cleanly."""
    model_output = (
        'Here is the result: {"name": "getOpenIncidentsTool", '
        '"parameters": {}} Would you like to know more?'
    )
    extracted = parser.extract_tool_calls(model_output, None)
    assert isinstance(extracted, ExtractedToolCallInformation)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1
    only_call = extracted.tool_calls[0]
    assert only_call.type == "function"
    assert only_call.function.name == "getOpenIncidentsTool"
    assert only_call.function.arguments == "{}"
    assert extracted.content is None
def test_extract_tool_calls_with_arguments(parser):
    """A tool call with populated parameters keeps its argument values."""
    model_output = (
        '{"name": "searchTool", "parameters": {"query": "test query", "limit": 10}}'
    )
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1
    only_call = extracted.tool_calls[0]
    assert only_call.function.name == "searchTool"
    assert '"query": "test query"' in only_call.function.arguments
    assert '"limit": 10' in only_call.function.arguments
def test_extract_tool_calls_no_json(parser):
    """Output with no JSON object at all is returned as plain content."""
    model_output = "This is just some text without any tool calls"
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is False
    assert not extracted.tool_calls
    assert extracted.content == model_output
def test_extract_tool_calls_invalid_json(parser):
    """Unparseable JSON degrades to returning the raw text as content."""
    model_output = '{"name": "invalidTool", "parameters": {invalid json}'
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is False
    assert not extracted.tool_calls
    assert extracted.content == model_output
def test_extract_tool_calls_with_arguments_key(parser):
    """The parser accepts "arguments" as a synonym for "parameters"."""
    model_output = '{"name": "searchTool", "arguments": {"query": "test"}}'
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1
    only_call = extracted.tool_calls[0]
    assert only_call.function.name == "searchTool"
    assert '"query": "test"' in only_call.function.arguments
def test_extract_tool_calls_multiple_json(parser):
    """Semicolon-separated JSON objects become separate tool calls, in order."""
    model_output = (
        '{"name": "searchTool", "parameters": {"query": "test1"}}; '
        '{"name": "getOpenIncidentsTool", "parameters": {}}; '
        '{"name": "searchTool", "parameters": {"query": "test2"}}'
    )
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 3
    first, second, third = extracted.tool_calls
    assert first.function.name == "searchTool"
    assert '"query": "test1"' in first.function.arguments
    assert second.function.name == "getOpenIncidentsTool"
    assert second.function.arguments == "{}"
    assert third.function.name == "searchTool"
    assert '"query": "test2"' in third.function.arguments
def test_extract_tool_calls_multiple_json_with_whitespace(parser):
    """Whitespace around the semicolon separators is tolerated."""
    model_output = (
        '{"name": "searchTool", "parameters": {"query": "test1"}} ; '
        '{"name": "getOpenIncidentsTool", "parameters": {}} ; '
        '{"name": "searchTool", "parameters": {"query": "test2"}}'
    )
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 3
    names = [call.function.name for call in extracted.tool_calls]
    assert names == ["searchTool", "getOpenIncidentsTool", "searchTool"]
def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
    """Prose before and after multiple tool calls does not break extraction."""
    model_output = (
        "Here are the results: "
        '{"name": "searchTool", "parameters": {"query": "test1"}}; '
        '{"name": "getOpenIncidentsTool", "parameters": {}}; '
        '{"name": "searchTool", "parameters": {"query": "test2"}} '
        "Would you like to know more?"
    )
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 3
    names = [call.function.name for call in extracted.tool_calls]
    assert names == ["searchTool", "getOpenIncidentsTool", "searchTool"]
def test_extract_tool_calls_deeply_nested_json(parser):
    """Five levels of nested parameters survive extraction intact."""
    import json

    model_output = (
        '{"name": "complexTool", '
        '"parameters": {'
        '"level1": {'
        '"level2": {'
        '"level3": {'
        '"level4": {'
        '"value": "deep"'
        "}}}}}}"
    )
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1
    only_call = extracted.tool_calls[0]
    assert only_call.function.name == "complexTool"
    # Verify the nested structure round-trips through the arguments string.
    decoded = json.loads(only_call.function.arguments)
    assert decoded["level1"]["level2"]["level3"]["level4"]["value"] == "deep"
def test_extract_tool_calls_multiple_with_deep_nesting(parser):
    """A flat call and a deeply nested call can coexist in one output."""
    import json

    model_output = (
        '{"name": "simpleTool", "parameters": {"value": "test"}}; '
        '{"name": "complexTool", "parameters": '
        '{"config": {"database": {"connection": {"pool": {"size": 10}}}}}}'
    )
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 2
    flat_call, nested_call = extracted.tool_calls
    assert flat_call.function.name == "simpleTool"
    assert json.loads(flat_call.function.arguments)["value"] == "test"
    assert nested_call.function.name == "complexTool"
    nested_args = json.loads(nested_call.function.arguments)
    assert nested_args["config"]["database"]["connection"]["pool"]["size"] == 10
def test_extract_tool_calls_with_quotes_and_brackets_in_string(parser):
    """Braces and brackets inside quoted values must not confuse the scanner."""
    import json

    model_output = (
        '{"name": "searchTool", '
        '"parameters": {'
        '"query": "test {value} [complex]",'
        '"nested": {"inner": "more {brackets}"}'
        "}}"
    )
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1
    only_call = extracted.tool_calls[0]
    assert only_call.function.name == "searchTool"
    decoded = json.loads(only_call.function.arguments)
    assert decoded["query"] == "test {value} [complex]"
    assert decoded["nested"]["inner"] == "more {brackets}"
def test_extract_tool_calls_with_escaped_quotes_in_nested_json(parser):
    """Escaped quotes inside string values are preserved through extraction."""
    import json

    model_output = (
        '{"name": "parserTool", "parameters": {"text": "He said \\"Hello {world}\\""}}'
    )
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1
    only_call = extracted.tool_calls[0]
    assert only_call.function.name == "parserTool"
    decoded = json.loads(only_call.function.arguments)
    assert decoded["text"] == 'He said "Hello {world}"'
def test_extract_tool_calls_missing_name_key(parser):
    """JSON without a "name" key is not a tool call; text is kept as content."""
    model_output = '{"parameters": {}}'
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is False
    assert not extracted.tool_calls
    assert extracted.content == model_output
def test_extract_tool_calls_missing_parameters_and_arguments_key(parser):
    """JSON lacking both "parameters" and "arguments" is treated as content."""
    model_output = '{"name": "toolWithoutParams"}'
    extracted = parser.extract_tool_calls(model_output, None)
    assert extracted.tools_called is False
    assert not extracted.tool_calls
    assert extracted.content == model_output
def test_regex_timeout_handling(parser):
    """Test regex timeout is handled gracefully"""
    fake_problematic_input = "{hello world[A(A=" + "\t)A(A=,\t" * 2
    # Regex stub whose finditer always raises, simulating a runaway match.
    mock_regex = MagicMock()
    mock_regex.finditer.side_effect = TimeoutError("Regex timeout")
    with patch.object(parser, "tool_call_start_regex", mock_regex):
        outcome = parser.extract_tool_calls(fake_problematic_input, None)
    # A timed-out scan must degrade to returning the raw text as content.
    assert outcome.content == fake_problematic_input
    assert outcome.tools_called is False
    assert not outcome.tool_calls
    mock_regex.finditer.assert_called_once()

View File

@@ -0,0 +1,269 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
# Test cases similar to pythonic parser but with Llama4 specific format
# Each *_OUTPUT constant is raw pythonic-style model output; the matching
# *_CALL constant is the expected parse result with JSON-encoded arguments.
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "LA", "metric": "C"}',
)
# Mixed argument types: dict, None, bool, list.
MORE_TYPES_FUNCTION_OUTPUT = (
    "[register_user(name='Doe', "
    "age=9, "
    "address={'city': 'LA', 'state': 'CA'}, "
    "role=None, "
    "passed_test=True, "
    "aliases=['John', 'Johnny'])]"
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
    name="register_user",
    arguments='{"name": "Doe", '
    '"age": 9, '
    '"address": {"city": "LA", "state": "CA"}, '
    '"role": null, '
    '"passed_test": true, '
    '"aliases": ["John", "Johnny"]}',
)
PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments="{}",
)
EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"additional_data": {}}',
)
EMPTY_LIST_FUNCTION_OUTPUT = "[do_something_cool(steps=[])]"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"steps": []}',
)
# Escaped single and double quotes inside string arguments.
ESCAPED_STRING_FUNCTION_OUTPUT = (
    r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]"
)
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
# Output wrapped in Llama4's python tag markers.
PYTHON_TAG_FUNCTION_OUTPUT = (
    "<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>"
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text must pass through as content with no tool calls extracted."""
    parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
        default_tokenizer
    )
    plain_text = "How can I help you today?"
    content, calls = run_tool_extraction(parser, plain_text, streaming=streaming)
    assert content == plain_text
    assert not calls
# Multi-call output wrapped in Llama4's python tag, shared by the
# python-tag parallel-call cases below.
test_str = "<|python_start|>"
test_str += "[get_weather(city='LA', metric='C'),"
test_str += "register_user(name='Doe', age=9)]"
# (streaming, model_output, expected_tool_calls) triples.
# Fixes vs. previous revision:
#   * "simple_streaming" used the ESCAPED_STRING fixtures (duplicating
#     "escaped_string_streaming"), leaving the simple streaming path untested.
#   * The python-tag parallel cases reused the ids
#     "parallel_calls_streaming"/"parallel_calls_nonstreaming"; pytest ids
#     must be unique so each case is addressable with -k.
TEST_CASES = [
    pytest.param(
        True,
        SIMPLE_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        id="simple_streaming",
    ),
    pytest.param(
        False, SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], id="simple_nonstreaming"
    ),
    pytest.param(
        True,
        MORE_TYPES_FUNCTION_OUTPUT,
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_streaming",
    ),
    pytest.param(
        False,
        MORE_TYPES_FUNCTION_OUTPUT,
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_nonstreaming",
    ),
    pytest.param(
        True,
        PARAMETERLESS_FUNCTION_OUTPUT,
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_streaming",
    ),
    pytest.param(
        False,
        PARAMETERLESS_FUNCTION_OUTPUT,
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_nonstreaming",
    ),
    pytest.param(
        True,
        EMPTY_DICT_FUNCTION_OUTPUT,
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_streaming",
    ),
    pytest.param(
        False,
        EMPTY_DICT_FUNCTION_OUTPUT,
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_nonstreaming",
    ),
    pytest.param(
        True,
        EMPTY_LIST_FUNCTION_OUTPUT,
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_streaming",
    ),
    pytest.param(
        False,
        EMPTY_LIST_FUNCTION_OUTPUT,
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_nonstreaming",
    ),
    pytest.param(
        True,
        ESCAPED_STRING_FUNCTION_OUTPUT,
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_streaming",
    ),
    pytest.param(
        False,
        ESCAPED_STRING_FUNCTION_OUTPUT,
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_nonstreaming",
    ),
    pytest.param(
        True,
        "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
        [
            SIMPLE_FUNCTION_CALL,
            FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
        ],
        id="parallel_calls_streaming",
    ),
    pytest.param(
        False,
        "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
        [
            SIMPLE_FUNCTION_CALL,
            FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
        ],
        id="parallel_calls_nonstreaming",
    ),
    pytest.param(
        True,
        PYTHON_TAG_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        id="python_tag_streaming",
    ),
    pytest.param(
        False,
        PYTHON_TAG_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        id="python_tag_nonstreaming",
    ),
    pytest.param(
        True,
        test_str,
        [
            SIMPLE_FUNCTION_CALL,
            FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
        ],
        id="python_tag_parallel_calls_streaming",
    ),
    pytest.param(
        False,
        "<|python_start|>[get_weather(city='LA', metric='C'), "
        + "register_user(name='Doe', age=9)]",
        [
            SIMPLE_FUNCTION_CALL,
            FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
        ],
        id="python_tag_parallel_calls_nonstreaming",
    ),
]
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    default_tokenizer: TokenizerLike,
):
    """Each TEST_CASES output yields exactly the expected function calls."""
    parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
        default_tokenizer
    )
    _, calls = run_tool_extraction(parser, model_output, streaming=streaming)
    assert len(calls) == len(expected_tool_calls)
    for got, want in zip(calls, expected_tool_calls):
        assert got.type == "function"
        assert got.function == want
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """A single large delta holding three calls is parsed into all of them."""
    parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
        default_tokenizer
    )
    deltas = [
        "<|python_start|>[get_weather(city='LA', metric='C'), "
        "get_weather(), "
        "do_something_cool(steps=[])]<|python_end|>",
    ]
    reconstructor = run_tool_extraction_streaming(
        parser, deltas, assert_one_tool_per_delta=False
    )
    assert reconstructor.other_content == ""
    expected = [
        SIMPLE_FUNCTION_CALL,
        PARAMETERLESS_FUNCTION_CALL,
        EMPTY_LIST_FUNCTION_CALL,
    ]
    assert len(reconstructor.tool_calls) == len(expected)
    for got, want in zip(reconstructor.tool_calls, expected):
        assert got.function == want
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
    """A TimeoutError from the tool-call regex degrades to plain content."""
    parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
        default_tokenizer
    )
    pathological_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
    # Stand-in regex whose match() always times out.
    timeout_regex = MagicMock()
    timeout_regex.match.side_effect = TimeoutError("Regex timeout")
    with patch.object(parser, "TOOL_CALL_REGEX", timeout_regex):
        content, calls = run_tool_extraction(
            parser, pathological_input, streaming=streaming
        )
    # The parser must fall back to treating the output as ordinary text.
    assert content == pathological_input
    assert not calls
    timeout_regex.match.assert_called_once()

View File

@@ -0,0 +1,251 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
# Model output fixtures ("*_OUTPUT") paired with the FunctionCall objects
# ("*_CALL") the olmo3 parser is expected to extract from them.
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "San Francisco", "metric": "celsius"}',
)
# Exercises strings, ints, nested dicts, None, bools and lists in one call.
MORE_TYPES_FUNCTION_OUTPUT = (
    "register_user(name='John Doe', "
    "age=37, "
    "address={'city': 'San Francisco', 'state': 'CA'}, "
    "role=None, "
    "passed_test=True, "
    "aliases=['John', 'Johnny'])"
)
# Same call but using JSON literals (null/true) instead of Python ones;
# both spellings must map to the same FunctionCall.
MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS = (
    "register_user(name='John Doe', "
    "age=37, "
    "address={'city': 'San Francisco', 'state': 'CA'}, "
    "role=null, "
    "passed_test=true, "
    "aliases=['John', 'Johnny'])"
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
    name="register_user",
    arguments='{"name": "John Doe", '
    '"age": 37, '
    '"address": {"city": "San Francisco", "state": "CA"}, '
    '"role": null, '
    '"passed_test": true, '
    '"aliases": ["John", "Johnny"]}',
)
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments="{}",
)
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"additional_data": {}}',
)
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"steps": []}',
)
# Backslash-escaped quotes inside Python-style string arguments.
ESCAPED_STRING_FUNCTION_OUTPUT = (
    r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
)
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text with no <function_calls> block passes through as content."""
    parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
        default_tokenizer
    )
    plain_text = "How can I help you today?"
    content, calls = run_tool_extraction(parser, plain_text, streaming=streaming)
    assert content == plain_text
    assert not calls
# (streaming, model_output, expected_tool_calls) triples; olmo3 wraps calls
# in <function_calls>...</function_calls> with one call per line.
TEST_CASES = [
    pytest.param(
        True,
        f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
        [SIMPLE_FUNCTION_CALL],
        id="simple_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
        [SIMPLE_FUNCTION_CALL],
        id="simple_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_streaming_json_literals",
    ),
    pytest.param(
        False,
        f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_nonstreaming_json_literals",
    ),
    pytest.param(
        True,
        f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_nonstreaming",
    ),
    pytest.param(
        True,
        f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
        [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
        id="parallel_calls_streaming",
    ),
    pytest.param(
        False,
        f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
        [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
        id="parallel_calls_nonstreaming",
    ),
]
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    default_tokenizer: TokenizerLike,
):
    """Each TEST_CASES output yields the expected calls and no content."""
    parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
        default_tokenizer
    )
    content, calls = run_tool_extraction(parser, model_output, streaming=streaming)
    assert content is None
    assert len(calls) == len(expected_tool_calls)
    for got, want in zip(calls, expected_tool_calls):
        assert got.type == "function"
        assert got.function == want
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """Two coarse deltas spanning three calls are still parsed into all three."""
    parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
        default_tokenizer
    )
    deltas = [
        "<function_calls>get_weather(city='San",
        (
            " Francisco', metric='celsius')\n"
            f"{PARAMETERLESS_FUNCTION_OUTPUT}\n"
            f"{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>"
        ),
    ]
    reconstructor = run_tool_extraction_streaming(
        parser, deltas, assert_one_tool_per_delta=False
    )
    assert reconstructor.other_content == ""
    expected = [
        SIMPLE_FUNCTION_CALL,
        PARAMETERLESS_FUNCTION_CALL,
        EMPTY_LIST_FUNCTION_CALL,
    ]
    assert len(reconstructor.tool_calls) == len(expected)
    for got, want in zip(reconstructor.tool_calls, expected):
        assert got.function == want
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
    """A TimeoutError from the tool-call regex degrades to plain content."""
    parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
        default_tokenizer
    )
    pathological_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
    # Stand-in regex whose match() always times out.
    timeout_regex = MagicMock()
    timeout_regex.match.side_effect = TimeoutError("Regex timeout")
    with patch.object(parser, "TOOL_CALL_REGEX", timeout_regex):
        content, calls = run_tool_extraction(
            parser, pathological_input, streaming=streaming
        )
    # The parser must fall back to treating the output as ordinary text.
    assert content == pathological_input
    assert not calls
    timeout_regex.match.assert_called_once()

View File

@@ -0,0 +1,359 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import jsonschema
import openai
import pytest
import pytest_asyncio
from rapidfuzz import fuzz
from ....utils import RemoteOpenAIServer
MODEL_NAME = "openai/gpt-oss-20b"


@pytest.fixture(scope="module")
def server():
    """Module-scoped vLLM server configured with the `openai` tool parser."""
    cli_args = [
        "--max-model-len",
        "8192",
        "--enforce-eager",
        "--enable-auto-tool-choice",
        "--tool-call-parser",
        "openai",
    ]
    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
@pytest_asyncio.fixture
async def client(server):
    """Async fixture providing an OpenAI-compatible vLLM client."""
    client_cm = server.get_async_client()
    async with client_cm as async_client:
        yield async_client
# ==========================================================
# Tool Definitions
# ==========================================================
# OpenAI-style tool schemas passed verbatim to chat.completions.create();
# test_tool_response_schema_accuracy validates call arguments against the
# "parameters" JSON schemas declared here.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "calculator",
            "description": "Performs basic arithmetic calculations.",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {
                        "type": "string",
                        "description": (
                            "Arithmetic expression to evaluate, e.g. '123 + 456'."
                        ),
                    }
                },
                "required": ["expression"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_time",
            "description": "Retrieves the current local time for a given city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "City name, e.g. 'New York'.",
                    }
                },
                "required": ["city"],
            },
        },
    },
]
# ==========================================================
# Message Examples
# ==========================================================
# Conversation fixtures used by the test scenarios below.
MESSAGES_CALC = [
    {"role": "user", "content": "Calculate 123 + 456 using the calculator."}
]
MESSAGES_GET_TIME = [
    {"role": "user", "content": "What is the current time in New York?"}
]
MESSAGES_MULTIPLE_CALLS = [
    {
        "role": "system",
        "content": (
            "You can call multiple tools. "
            # NOTE: trailing space added — the previous implicit string
            # concatenation produced "tool_calls arraycontaining".
            "When using more than one, return single JSON object with tool_calls array "
            "containing each tool call with its function name and arguments. "
            "Do not output multiple JSON objects separately."
        ),
    },
    {
        "role": "user",
        "content": "First, calculate 7 * 8 using the calculator. "
        "Then, use get_time to tell me the current time in New York.",
    },
]
MESSAGES_INVALID_CALL = [
    {
        "role": "user",
        "content": "Can you help with something, "
        "but dont actually perform any calculation?",
    }
]
# Expected outputs
# Function names the model is expected to call, with canonical argument
# payloads.
# NOTE(review): FUNC_ARGS_CALC / FUNC_ARGS_TIME are not referenced by the
# tests visible in this file — presumably kept as documentation of the
# expected argument shape; confirm before removing.
FUNC_CALC = "calculator"
FUNC_ARGS_CALC = '{"expression":"123 + 456"}'
FUNC_TIME = "get_time"
FUNC_ARGS_TIME = '{"city": "New York"}'
# ==========================================================
# Utility to extract reasoning and tool calls
# ==========================================================
def extract_reasoning_and_calls(chunks: list) -> tuple[str, list[str], list[str]]:
    """
    Extract accumulated reasoning text and tool call arguments
    from streaming chunks.

    Returns:
        (reasoning_content, arguments, function_names), with arguments and
        function_names ordered by tool-call index.
    """
    reasoning_content: str = ""
    tool_calls: dict[int, dict[str, str]] = {}
    for chunk in chunks:
        # Some streaming chunks carry an empty ``choices`` list (e.g. a
        # final usage chunk); skip them instead of raising IndexError.
        if not chunk.choices:
            continue
        choice = getattr(chunk.choices[0], "delta", None)
        if not choice:
            continue
        if hasattr(choice, "reasoning_content") and choice.reasoning_content:
            reasoning_content += choice.reasoning_content
        for tc in getattr(choice, "tool_calls", []) or []:
            # Defaults cover deltas that omit index/function fields.
            idx = getattr(tc, "index", 0)
            tool_entry = tool_calls.setdefault(idx, {"name": "", "arguments": ""})
            if getattr(tc, "function", None):
                func = tc.function
                # The name arrives once; argument fragments are concatenated.
                if getattr(func, "name", None):
                    tool_entry["name"] = func.name
                if getattr(func, "arguments", None):
                    tool_entry["arguments"] += func.arguments
    function_names: list[str] = [v["name"] for _, v in sorted(tool_calls.items())]
    arguments: list[str] = [v["arguments"] for _, v in sorted(tool_calls.items())]
    return reasoning_content, arguments, function_names
# ==========================================================
# Test Scenarios
# ==========================================================
@pytest.mark.asyncio
async def test_calculator_tool_call_and_argument_accuracy(client: openai.AsyncOpenAI):
    """Verify calculator tool call is made and arguments are accurate."""
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_CALC,
        tools=TOOLS,
        temperature=0.0,
        stream=False,
    )
    message = response.choices[0].message
    calls = getattr(message, "tool_calls", [])
    assert calls, "No tool calls detected"
    calc_call = next((c for c in calls if c.function.name == FUNC_CALC), None)
    assert calc_call, "Calculator function not called"
    raw_args = calc_call.function.arguments
    assert raw_args, "Calculator arguments missing"
    assert "123" in raw_args and "456" in raw_args, (
        f"Expected values not in raw arguments: {raw_args}"
    )
    try:
        parsed_args = json.loads(raw_args)
    except json.JSONDecodeError:
        pytest.fail(f"Invalid JSON in calculator arguments: {raw_args}")
    # Fuzzy-match the expression so harmless whitespace drift still passes.
    expected_expr = "123 + 456"
    actual_expr = parsed_args.get("expression", "")
    similarity = fuzz.ratio(actual_expr, expected_expr)
    assert similarity > 90, (
        f"Expression mismatch: expected '{expected_expr}' "
        f"got '{actual_expr}' (similarity={similarity}%)"
    )
@pytest.mark.asyncio
async def test_streaming_tool_call_get_time_with_reasoning(client: openai.AsyncOpenAI):
    """Verify streamed reasoning and tool call behavior for get_time."""
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_GET_TIME,
        tools=TOOLS,
        temperature=0.0,
        stream=True,
    )
    collected = []
    async for chunk in stream:
        collected.append(chunk)
    reasoning, arguments, function_names = extract_reasoning_and_calls(collected)
    assert FUNC_TIME in function_names, "get_time function not called"
    assert any("New York" in arg for arg in arguments), (
        f"Expected get_time arguments for New York not found in {arguments}"
    )
    assert len(reasoning) > 0, "Expected reasoning content missing"
    assert any(keyword in reasoning for keyword in ["New York", "time", "current"]), (
        f"Reasoning is not relevant to the request: {reasoning}"
    )
@pytest.mark.asyncio
async def test_streaming_multiple_tools(client: openai.AsyncOpenAI):
    """Test streamed multi-tool response with reasoning.

    The assertions must propagate: the previous revision wrapped them in
    ``try/except AssertionError: print(...)``, which made the test pass
    unconditionally.
    """
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_MULTIPLE_CALLS,
        tools=TOOLS,
        temperature=0.0,
        stream=True,
    )
    chunks = [chunk async for chunk in stream]
    reasoning, _arguments, function_names = extract_reasoning_and_calls(chunks)
    assert FUNC_CALC in function_names, (
        f"Calculator tool missing — found {function_names}"
    )
    assert FUNC_TIME in function_names, (
        f"Time tool missing — found {function_names}"
    )
    assert len(reasoning) > 0, "Expected reasoning content in streamed response"
@pytest.mark.asyncio
async def test_invalid_tool_call(client: openai.AsyncOpenAI):
    """
    Verify that ambiguous instructions that should not trigger a tool
    do not produce any tool calls.
    """
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_INVALID_CALL,
        tools=TOOLS,
        temperature=0.0,
        stream=False,
    )
    msg = response.choices[0].message
    assert msg is not None, "Expected message in response"
    assert hasattr(msg, "content"), "Expected 'content' field in message"
    tool_calls = getattr(msg, "tool_calls", [])
    assert not tool_calls, (
        f"Model unexpectedly attempted a tool call on invalid input: {tool_calls}"
    )
@pytest.mark.asyncio
async def test_tool_call_with_temperature(client: openai.AsyncOpenAI):
    """
    Verify model produces valid tool or text output
    under non-deterministic sampling.
    """
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_CALC,
        tools=TOOLS,
        temperature=0.7,
        stream=False,
    )
    message = response.choices[0].message
    assert message is not None, "Expected non-empty message in response"
    produced_output = bool(message.tool_calls) or bool(message.content)
    assert produced_output, "Response missing both text and tool calls"
    print(f"\nTool calls: {message.tool_calls}")
    print(f"Text: {message.content}")
@pytest.mark.asyncio
async def test_tool_response_schema_accuracy(client: openai.AsyncOpenAI):
    """Validate that tool call arguments adhere to their declared JSON schema."""
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_MULTIPLE_CALLS,
        tools=TOOLS,
        temperature=0.0,
    )
    calls = response.choices[0].message.tool_calls
    assert calls, "No tool calls produced"
    # Index declared parameter schemas by function name up front.
    schemas_by_name: dict[str, object] = {}
    for tool_entry in TOOLS:
        function_def = tool_entry.get("function")
        if function_def and isinstance(function_def, dict):
            schemas_by_name[function_def.get("name")] = function_def.get("parameters")
    for call in calls:
        func_name = call.function.name
        args = json.loads(call.function.arguments)
        schema = schemas_by_name.get(func_name)
        assert schema is not None, f"No matching tool schema found for {func_name}"
        jsonschema.validate(instance=args, schema=schema)
@pytest.mark.asyncio
async def test_semantic_consistency_with_temperature(client: openai.AsyncOpenAI):
    """Test that temperature variation doesn't cause contradictory reasoning."""
    # NOTE(review): three responses (T=0.0, 0.5, 1.0) are collected but only
    # the first two are compared — the T=1.0 response is unused. Presumably
    # intentional (high-temperature output is too noisy to bound); confirm.
    responses = []
    for temp in [0.0, 0.5, 1.0]:
        resp = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=MESSAGES_CALC,
            tools=TOOLS,
            temperature=temp,
        )
        # Text content only; tool-call payloads are not compared here.
        text = (resp.choices[0].message.content or "").strip()
        responses.append(text)
    # Compare fuzzy similarity between low- and mid-temperature outputs
    low_mid_sim = fuzz.ratio(responses[0], responses[1])
    assert low_mid_sim > 60, (
        f"Semantic drift too large between T=0.0 and T=0.5 ({low_mid_sim}%)"
    )

View File

@@ -0,0 +1,231 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
# Model output fixtures ("*_OUTPUT") paired with the FunctionCall objects
# ("*_CALL") the pythonic parser is expected to extract from them.
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "San Francisco", "metric": "celsius"}',
)
# Exercises strings, ints, nested dicts, None, bools and lists in one call.
MORE_TYPES_FUNCTION_OUTPUT = (
    "register_user(name='John Doe', "
    "age=37, "
    "address={'city': 'San Francisco', 'state': 'CA'}, "
    "role=None, "
    "passed_test=True, "
    "aliases=['John', 'Johnny'])"
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
    name="register_user",
    arguments='{"name": "John Doe", '
    '"age": 37, '
    '"address": {"city": "San Francisco", "state": "CA"}, '
    '"role": null, '
    '"passed_test": true, '
    '"aliases": ["John", "Johnny"]}',
)
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments="{}",
)
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"additional_data": {}}',
)
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"steps": []}',
)
# Backslash-escaped quotes inside Python-style string arguments.
ESCAPED_STRING_FUNCTION_OUTPUT = (
    r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
)
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text with no bracketed call list passes through as content."""
    parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
        default_tokenizer
    )
    plain_text = "How can I help you today?"
    content, calls = run_tool_extraction(parser, plain_text, streaming=streaming)
    assert content == plain_text
    assert not calls
# (streaming, model_output, expected_tool_calls) triples; the pythonic
# format wraps comma-separated calls in a single bracketed list.
TEST_CASES = [
    pytest.param(
        True,
        f"[{SIMPLE_FUNCTION_OUTPUT}]",
        [SIMPLE_FUNCTION_CALL],
        id="simple_streaming",
    ),
    pytest.param(
        False,
        f"[{SIMPLE_FUNCTION_OUTPUT}]",
        [SIMPLE_FUNCTION_CALL],
        id="simple_nonstreaming",
    ),
    pytest.param(
        True,
        f"[{MORE_TYPES_FUNCTION_OUTPUT}]",
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_streaming",
    ),
    pytest.param(
        False,
        f"[{MORE_TYPES_FUNCTION_OUTPUT}]",
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_nonstreaming",
    ),
    pytest.param(
        True,
        f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_streaming",
    ),
    pytest.param(
        False,
        f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_nonstreaming",
    ),
    pytest.param(
        True,
        f"[{EMPTY_DICT_FUNCTION_OUTPUT}]",
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_streaming",
    ),
    pytest.param(
        False,
        f"[{EMPTY_DICT_FUNCTION_OUTPUT}]",
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_nonstreaming",
    ),
    pytest.param(
        True,
        f"[{EMPTY_LIST_FUNCTION_OUTPUT}]",
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_streaming",
    ),
    pytest.param(
        False,
        f"[{EMPTY_LIST_FUNCTION_OUTPUT}]",
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_nonstreaming",
    ),
    pytest.param(
        True,
        f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_streaming",
    ),
    pytest.param(
        False,
        f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_nonstreaming",
    ),
    pytest.param(
        True,
        f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
        [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
        id="parallel_calls_streaming",
    ),
    pytest.param(
        False,
        f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
        [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
        id="parallel_calls_nonstreaming",
    ),
]
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    default_tokenizer: TokenizerLike,
):
    """Each TEST_CASES output yields the expected calls and no content."""
    parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
        default_tokenizer
    )
    content, calls = run_tool_extraction(parser, model_output, streaming=streaming)
    assert content is None
    assert len(calls) == len(expected_tool_calls)
    for got, want in zip(calls, expected_tool_calls):
        assert got.type == "function"
        assert got.function == want
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """Two coarse deltas spanning three calls are still parsed into all three."""
    parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
        default_tokenizer
    )
    deltas = [
        "[get_weather(city='San",
        (
            " Francisco', metric='celsius'), "
            f"{PARAMETERLESS_FUNCTION_OUTPUT}, "
            f"{EMPTY_LIST_FUNCTION_OUTPUT}]"
        ),
    ]
    reconstructor = run_tool_extraction_streaming(
        parser, deltas, assert_one_tool_per_delta=False
    )
    assert reconstructor.other_content == ""
    expected = [
        SIMPLE_FUNCTION_CALL,
        PARAMETERLESS_FUNCTION_CALL,
        EMPTY_LIST_FUNCTION_CALL,
    ]
    assert len(reconstructor.tool_calls) == len(expected)
    for got, want in zip(reconstructor.tool_calls, expected):
        assert got.function == want
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
    """A TimeoutError from the tool-call regex degrades to plain content."""
    parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
        default_tokenizer
    )
    pathological_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
    # Stand-in regex whose match() always times out.
    timeout_regex = MagicMock()
    timeout_regex.match.side_effect = TimeoutError("Regex timeout")
    with patch.object(parser, "TOOL_CALL_REGEX", timeout_regex):
        content, calls = run_tool_extraction(
            parser, pathological_input, streaming=streaming
        )
    # The parser must fall back to treating the output as ordinary text.
    assert content == pathological_input
    assert not calls
    timeout_regex.match.assert_called_once()

View File

@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
class StreamingToolReconstructor:
    """Accumulates streaming DeltaMessages into complete tool calls/content.

    Used by the streaming tool-parser tests to rebuild the final message
    from a sequence of deltas, asserting protocol invariants along the way
    (one id/name per call, arguments only appended to the most recent call).
    """

    def __init__(self, assert_one_tool_per_delta: bool = True):
        # Fully reconstructed tool calls, ordered by their streamed index.
        self.tool_calls: list[ToolCall] = []
        # Concatenation of all plain-text content deltas.
        self.other_content: str = ""
        self._assert_one_tool_per_delta = assert_one_tool_per_delta

    def append_delta(self, delta: DeltaMessage):
        """Fold a single streamed delta into the reconstructed state."""
        if delta.content is not None:
            self.other_content += delta.content
        else:
            # A delta with no content must carry at least one tool call.
            assert delta.tool_calls, (
                "Streaming results should have either content or tool calls (or both)"
            )
        if self._assert_one_tool_per_delta:
            # Note: This isn't strictly required by the API and may not be
            # possible to adhere to depending on the token space and number of
            # tokens per streamed response from the model, but it is required
            # by tool_use tests, so we enforce it here by default also.
            assert len(delta.tool_calls) < 2, (
                "Streaming should include only one tool call per update."
            )
        for call_delta in delta.tool_calls:
            assert call_delta.type is None or call_delta.type == "function", (
                "Streaming tool calls should only emit function calls. Got "
                f"{call_delta.type}"
            )
            # A delta either extends an already-seen call (index in range)
            # or opens a new one (index == len(self.tool_calls)).
            current_tool_call = (
                self.tool_calls[call_delta.index]
                if call_delta.index < len(self.tool_calls)
                else None
            )
            if current_tool_call:
                # Continuation: only argument fragments may be streamed —
                # name/id were emitted on the call's first delta, and only
                # the most recent call may receive further fragments.
                assert not call_delta.function.name, (
                    "Streaming tool calls should emit the full function name "
                    f"exactly once. Got {call_delta.function.name}"
                )
                assert not call_delta.id, (
                    "Streaming tool calls must emit function id only once. Got "
                    f"{call_delta.id}"
                )
                assert call_delta.index == len(self.tool_calls) - 1, (
                    f"Incorrect index for tool delta. Got {call_delta.index}, "
                    f"expected {len(self.tool_calls) - 1}"
                )
                current_tool_call.function.arguments += call_delta.function.arguments
            else:
                # First delta of a new call: id and name are mandatory.
                assert call_delta.id is not None, (
                    "Streaming tool calls must have an id on first appearance"
                )
                assert call_delta.function.name is not None, (
                    "Streaming tool calls must have a function name on first appearance"
                )
                assert call_delta.index == len(self.tool_calls), (
                    f"Incorrect index for tool delta. Got {call_delta.index}, "
                    f"expected {len(self.tool_calls)}"
                )
                self.tool_calls.append(
                    ToolCall(
                        id=call_delta.id,
                        function=FunctionCall(
                            name=call_delta.function.name,
                            arguments=call_delta.function.arguments or "",
                        ),
                    )
                )
def run_tool_extraction(
    tool_parser: ToolParser,
    model_output: str,
    request: ChatCompletionRequest | None = None,
    streaming: bool = False,
    assert_one_tool_per_delta: bool = True,
) -> tuple[str | None, list[ToolCall]]:
    """Extract (content, tool_calls) via the streaming or one-shot path."""
    if not streaming:
        extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, request)
        # tools_called must agree with whether any calls were extracted.
        assert extracted.tools_called == bool(extracted.tool_calls)
        return extracted.content, extracted.tool_calls
    reconstructor = run_tool_extraction_streaming(
        tool_parser,
        model_output,
        request,
        assert_one_tool_per_delta=assert_one_tool_per_delta,
    )
    return reconstructor.other_content or None, reconstructor.tool_calls
def run_tool_extraction_nonstreaming(
    tool_parser: ToolParser,
    model_output: str,
    request: ChatCompletionRequest | None = None,
) -> ExtractedToolCallInformation:
    """Extract tool calls in one shot, defaulting to a minimal request."""
    if request is None:
        request = ChatCompletionRequest(messages=[], model="test-model")
    return tool_parser.extract_tool_calls(model_output, request)
def split_string_into_token_deltas(tokenizer: TokenizerLike, text: str) -> list[str]:
    """Split *text* into per-token string deltas using *tokenizer*.

    Each delta is the new text contributed by decoding one additional
    token id, so concatenating the deltas reproduces the decoded text.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    deltas: list[str] = []
    decoded_so_far = ""
    for end in range(1, len(token_ids) + 1):
        decoded = tokenizer.decode(token_ids[:end])
        deltas.append(decoded[len(decoded_so_far) :])
        decoded_so_far = decoded
    return deltas
def run_tool_extraction_streaming(
    tool_parser: ToolParser,
    model_deltas: Iterable[str],
    request: ChatCompletionRequest | None = None,
    assert_one_tool_per_delta: bool = True,
) -> StreamingToolReconstructor:
    """Feed text deltas through the parser's streaming API and reconstruct.

    A plain-string input is first split into per-token deltas. Returns the
    StreamingToolReconstructor holding all content and tool calls emitted.
    """
    if isinstance(model_deltas, str):
        model_deltas = split_string_into_token_deltas(
            tool_parser.model_tokenizer, model_deltas
        )
    request = request or ChatCompletionRequest(messages=[], model="test-model")
    reconstructor = StreamingToolReconstructor(
        assert_one_tool_per_delta=assert_one_tool_per_delta
    )
    previous_text = ""
    previous_tokens: list[int] = []
    for delta in model_deltas:
        # Map the delta back to token ids via the parser's vocab; pieces
        # missing from the vocab are dropped — presumably tokenize() can
        # emit sub-pieces without vocab entries (TODO confirm).
        token_delta = [
            tool_parser.vocab.get(token)
            for token in tool_parser.model_tokenizer.tokenize(delta)
            if token in tool_parser.vocab
        ]
        current_text = previous_text + delta
        current_tokens = previous_tokens + token_delta
        delta_message = tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta,
            previous_tokens,
            current_tokens,
            token_delta,
            request,
        )
        # Parsers may buffer internally and emit nothing for a given delta.
        if delta_message is not None:
            reconstructor.append_delta(delta_message)
        previous_text = current_text
        previous_tokens = current_tokens
    return reconstructor