Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

View File

@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.deepseekv31_tool_parser import (
DeepSeekV31ToolParser,
)
# Tokenizer/chat format under test for the DeepSeek-V3.1 tool parser.
MODEL = "deepseek-ai/DeepSeek-V3.1"
@pytest.fixture(scope="module")
def deepseekv31_tokenizer():
    """Load the DeepSeek-V3.1 tokenizer once for the whole test module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
@pytest.fixture
def parser(deepseekv31_tokenizer):
    """A fresh DeepSeekV31ToolParser for each test."""
    tool_parser = DeepSeekV31ToolParser(deepseekv31_tokenizer)
    return tool_parser
def test_extract_tool_calls_with_tool(parser):
    """One tool call after normal text: call parsed, text kept as content."""
    model_output = "".join(
        [
            "normal text",
            "<tool▁calls▁begin>",
            '<tool▁call▁begin>foo<tool▁sep>{"x":1}<tool▁call▁end>',
            "<tool▁calls▁end>",
        ]
    )
    result = parser.extract_tool_calls(model_output, None)
    assert result.tools_called
    assert len(result.tool_calls) == 1
    only_call = result.tool_calls[0].function
    assert only_call.name == "foo"
    assert only_call.arguments == '{"x":1}'
    assert result.content == "normal text"
def test_extract_tool_calls_with_multiple_tools(parser):
    """Two tool calls in one section: only the prefix text becomes content."""
    model_output = "".join(
        [
            "some prefix text",
            "<tool▁calls▁begin>",
            '<tool▁call▁begin>foo<tool▁sep>{"x":1}<tool▁call▁end>',
            '<tool▁call▁begin>bar<tool▁sep>{"y":2}<tool▁call▁end>',
            "<tool▁calls▁end>",
            " some suffix text",
        ]
    )
    result = parser.extract_tool_calls(model_output, None)
    assert result.tools_called
    assert len(result.tool_calls) == 2
    first, second = result.tool_calls
    assert first.function.name == "foo"
    assert first.function.arguments == '{"x":1}'
    assert second.function.name == "bar"
    assert second.function.arguments == '{"y":2}'
    # prefix is content
    assert result.content == "some prefix text"

View File

@@ -0,0 +1,359 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
from collections.abc import Generator
import pytest
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
FunctionCall,
ToolCall,
)
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tool_parsers.ernie45_tool_parser import Ernie45ToolParser
# Use a common model that is likely to be available; the fixture below loads
# its tokenizer with trust_remote_code=True.
MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking"
@pytest.fixture(scope="module")
def ernie45_tokenizer():
    """ERNIE-4.5 tokenizer, loaded once per test module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
    return tokenizer
@pytest.fixture
def ernie45_tool_parser(ernie45_tokenizer):
    """A fresh Ernie45ToolParser for each test."""
    tool_parser = Ernie45ToolParser(ernie45_tokenizer)
    return tool_parser
def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """Check tool calls pairwise: id, type, name, and JSON-equal arguments."""
    assert len(actual_tool_calls) == len(expected_tool_calls)
    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        assert isinstance(actual.id, str)
        assert len(actual.id) > 0
        assert actual.type == "function"
        assert actual.function.name == expected.function.name
        # Parse both argument strings so formatting differences don't matter.
        assert json.loads(actual.function.arguments) == json.loads(
            expected.function.arguments
        )
def test_extract_tool_calls_no_tools(ernie45_tool_parser):
    """Plain text without tool markers passes through unchanged."""
    model_output = "This is a test"
    result = ernie45_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == model_output
@pytest.mark.parametrize(
    ids=[
        "single_tool_call",
        "multiple_tool_calls",
        "tool_call_with_content_before",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        # One <tool_call> block and no surrounding text -> content is None.
        (
            """<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        # Two consecutive <tool_call> blocks -> both parsed, content is None.
        (
            """<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            None,
        ),
        # Reasoning text ending in </think> before the tool calls is kept
        # as content (including the </think> tag itself).
        (
            """I need to call two tools to handle these two issues separately.
</think>
<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            "I need to call two tools to handle these two issues separately.\n</think>",
        ),
    ],
)
def test_extract_tool_calls(
    ernie45_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction: tool calls and leftover content must match."""
    extracted_tool_calls = ernie45_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
    assert extracted_tool_calls.content == expected_content
def stream_delta_message_generator(
    ernie45_tool_parser: Ernie45ToolParser,
    ernie45_tokenizer: TokenizerLike,
    model_output: str,
    request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
    """Feed model_output to the parser one token at a time, yielding the
    non-empty DeltaMessages the streaming parser produces."""
    all_token_ids = ernie45_tokenizer.encode(model_output, add_special_tokens=False)
    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[: i + 1]
        # Incrementally detokenize so the text deltas match what the
        # engine's streaming path would hand to the parser.
        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
            detokenize_incrementally(
                tokenizer=ernie45_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            )
        )
        current_text = previous_text + delta_text
        delta_message = ernie45_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=request,
        )
        if delta_message:
            yield delta_message
        # Carry the detokenizer state forward to the next iteration.
        previous_text = current_text
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset
@pytest.mark.parametrize(
    ids=[
        "single_tool_call",
        "multiple_tool_calls",
        "tool_call_with_content_before",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        # One <tool_call> block and no surrounding text.
        (
            """<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        # Two consecutive <tool_call> blocks.
        (
            """<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            None,
        ),
        # Reasoning text (ending with </think>) before the tool calls.
        (
            """I need to call two tools to handle these two issues separately.
</think>
<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            "I need to call two tools to handle these two issues separately.\n</think>",
        ),
    ],
)
def test_extract_tool_calls_streaming_incremental(
    ernie45_tool_parser,
    ernie45_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Verify the Ernie45 Parser streaming behavior by verifying each chunk is as expected."""  # noqa: E501
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
    # Accumulates streamed chunks per tool-call index.
    tool_calls_dict = {}
    for delta_message in stream_delta_message_generator(
        ernie45_tool_parser, ernie45_tokenizer, model_output, request
    ):
        # Skip deltas that carry no role, content, reasoning, or tool calls.
        if (
            delta_message.role is None
            and delta_message.content is None
            and delta_message.reasoning is None
            and len(delta_message.tool_calls) == 0
        ):
            continue
        tool_calls = delta_message.tool_calls
        for tool_call_chunk in tool_calls:
            index = tool_call_chunk.index
            if index not in tool_calls_dict:
                # First chunk for this index: normalize missing arguments to
                # "" so later chunks can be concatenated onto it.
                if tool_call_chunk.function.arguments is None:
                    tool_call_chunk.function.arguments = ""
                tool_calls_dict[index] = tool_call_chunk
            else:
                tool_calls_dict[
                    index
                ].function.arguments += tool_call_chunk.function.arguments
    actual_tool_calls = list(tool_calls_dict.values())
    assert len(actual_tool_calls) > 0
    # check tool call format
    assert_tool_calls(actual_tool_calls, expected_tool_calls)

View File

@@ -0,0 +1,449 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.glm4_moe_tool_parser import (
Glm4MoeModelToolParser,
)
# Skip the whole module at collection time.
pytest.skip("skip glm4_moe parser test", allow_module_level=True)
# Use a common model that is likely to be available
MODEL = "zai-org/GLM-4.5"
@pytest.fixture(scope="module")
def glm4_moe_tokenizer():
    """GLM-4.5 tokenizer, loaded once per test module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
@pytest.fixture
def glm4_moe_tool_parser(glm4_moe_tokenizer):
    """A fresh Glm4MoeModelToolParser for each test."""
    tool_parser = Glm4MoeModelToolParser(glm4_moe_tokenizer)
    return tool_parser
def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """Verify tool calls match pairwise, comparing arguments as parsed JSON."""
    assert len(actual_tool_calls) == len(expected_tool_calls)
    for got, want in zip(actual_tool_calls, expected_tool_calls):
        assert isinstance(got.id, str)
        assert len(got.id) > 0
        assert got.type == "function"
        assert got.function.name == want.function.name
        # JSON-decode both sides so whitespace/key-order differences pass.
        got_args = json.loads(got.function.arguments)
        want_args = json.loads(want.function.arguments)
        assert got_args == want_args
def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
    """Plain text without tool markers passes through unchanged."""
    model_output = "This is a test"
    result = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == model_output
@pytest.mark.parametrize(
    ids=[
        "single_tool_call",
        "multiple_tool_calls",
        "tool_call_with_content_before",
        "tool_call_with_mixed_args",
        "tool_call_with_chinese_content",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        # Single tool call, no surrounding text -> content is None.
        (
            """<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Dallas</arg_value>
<arg_key>state</arg_key>
<arg_value>TX</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        # Two sequential tool calls -> both extracted, content is None.
        (
            """<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Dallas</arg_value>
<arg_key>state</arg_key>
<arg_value>TX</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call>
<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Orlando</arg_value>
<arg_key>state</arg_key>
<arg_value>FL</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Orlando",
                                "state": "FL",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
            ],
            None,
        ),
        # Leading prose before the tool call is kept as content
        # (without the trailing space before the tag).
        (
            """I'll help you check the weather. <tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Seattle</arg_value>
<arg_key>state</arg_key>
<arg_value>WA</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Seattle",
                                "state": "WA",
                                "unit": "celsius",
                            }
                        ),
                    )
                )
            ],
            "I'll help you check the weather.",
        ),
        # Argument value containing a space ("New York").
        (
            """<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>New York</arg_value>
<arg_key>state</arg_key>
<arg_value>NY</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "New York",
                                "state": "NY",
                                "unit": "celsius",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        # Content abutting the tag with no separating space.
        # NOTE(review): the id says "chinese_content" but this payload is
        # English — consider renaming or adding a Chinese-text case.
        (
            """I will help you get the weather.<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "city": "Beijing",
                                "date": "2025-08-01",
                            }
                        ),
                    )
                )
            ],
            "I will help you get the weather.",
        ),
    ],
)
def test_extract_tool_calls(
    glm4_moe_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction: tool calls and leftover content must match."""
    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
    assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_with_thinking_tags(glm4_moe_tool_parser):
    """Test tool extraction when thinking tags are present."""
    model_output = """<think>I want to get the weather.</think>
I will help you get the weather.
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>"""
    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    assert len(extracted_tool_calls.tool_calls) == 1
    assert extracted_tool_calls.tool_calls[0].function.name == "get_weather"
    # The <think> block is NOT stripped: everything before the first
    # <tool_call> tag — thinking included — is returned as content.
    expected_content = """<think>I want to get the weather.</think>
I will help you get the weather."""
    assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser):
    """A broken <arg_key> tag must not crash the parser."""
    model_output = """<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Seattle</arg_value>
<arg_key>incomplete_arg
<arg_value>value</arg_value>
</tool_call>"""
    result = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    # Whatever the parser salvages (or drops), the result must keep its
    # well-formed shape: a bool flag and a list of calls.
    assert isinstance(result.tools_called, bool)
    assert isinstance(result.tool_calls, list)
def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser):
    """A call with no <arg_key>/<arg_value> pairs yields "{}" arguments."""
    model_output = """<tool_call>get_current_time
</tool_call>"""
    result = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert result.tools_called
    assert len(result.tool_calls) == 1
    only_call = result.tool_calls[0]
    assert only_call.function.name == "get_current_time"
    # Empty arguments should result in empty JSON object
    assert only_call.function.arguments == "{}"
def test_extract_tool_calls_mixed_content(glm4_moe_tool_parser):
    """Test extraction with mixed content and multiple tool calls."""
    model_output = """I will help you get the weather info.
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>
meaningwhile, I will also check the weather in Shanghai.
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Shanghai</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>"""
    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    assert len(extracted_tool_calls.tool_calls) == 2
    # Check first tool call
    assert extracted_tool_calls.tool_calls[0].function.name == "get_weather"
    args1 = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
    assert args1["city"] == "Beijing"
    assert args1["date"] == "2025-08-01"
    # Check second tool call
    assert extracted_tool_calls.tool_calls[1].function.name == "get_weather"
    args2 = json.loads(extracted_tool_calls.tool_calls[1].function.arguments)
    assert args2["city"] == "Shanghai"
    assert args2["date"] == "2025-08-01"
    # Content should be everything before the first tool call; the text
    # between the two calls is not part of the returned content.
    assert extracted_tool_calls.content == "I will help you get the weather info."
def test_streaming_basic_functionality(glm4_moe_tool_parser):
    """Smoke-test the streaming path: it must not raise on a full tool call."""
    streaming_parser = glm4_moe_tool_parser
    # Reset any streaming state left over from other tests.
    streaming_parser.current_tool_name_sent = False
    streaming_parser.prev_tool_call_arr = []
    streaming_parser.current_tool_id = -1
    streaming_parser.streamed_args_for_tool = []
    current_text = """<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
</tool_call>"""
    # Fall back to dummy token ids when the tokenizer does not define them.
    start_id = streaming_parser.tool_call_start_token_id or 12345
    end_id = streaming_parser.tool_call_end_token_id or 12346
    result = streaming_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=current_text,
        delta_text="</tool_call>",
        previous_token_ids=[],
        current_token_ids=[start_id, end_id],
        delta_token_ids=[end_id],
        request=None,
    )
    # The exact delta depends on internal streaming state; this test only
    # requires a well-formed result (or None) and no exception.
    assert result is None or hasattr(result, "tool_calls") or hasattr(result, "content")
def test_streaming_no_tool_calls(glm4_moe_tool_parser):
    """Plain-text deltas are streamed straight through as content."""
    delta = " without any tool calls."
    result = glm4_moe_tool_parser.extract_tool_calls_streaming(
        previous_text="This is just regular text",
        current_text="This is just regular text" + delta,
        delta_text=delta,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )
    # Should return the delta text as content
    assert result is not None
    assert hasattr(result, "content")
    assert result.content == delta
def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser):
    """Content ahead of an opening tool tag is still streamed as content."""
    streaming_parser = glm4_moe_tool_parser
    # Reset streaming state from any prior test.
    streaming_parser.current_tool_name_sent = False
    streaming_parser.prev_tool_call_arr = []
    streaming_parser.current_tool_id = -1
    streaming_parser.streamed_args_for_tool = []
    result = streaming_parser.extract_tool_calls_streaming(
        previous_text="I will help you",
        current_text="I will help you get the weather<tool_call>",
        delta_text="get the weather.<tool_call>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )
    # With no tool-call token ids observed, the delta comes back as content.
    assert result is not None
    assert hasattr(result, "content")
    assert result.content == "get the weather.<tool_call>"
def test_extract_tool_calls_special_characters(glm4_moe_tool_parser):
    """Test a multi-argument tool call round-trips through JSON.

    NOTE(review): despite the test name, the payload below is plain ASCII —
    consider adding actual unicode/special-character values.
    """
    model_output = """<tool_call>send_message
<arg_key>recipient</arg_key>
<arg_value>Amy</arg_value>
<arg_key>message</arg_key>
<arg_value>It is a nice day</arg_value>
<arg_key>priority</arg_key>
<arg_value>high</arg_value>
</tool_call>"""
    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    assert len(extracted_tool_calls.tool_calls) == 1
    assert extracted_tool_calls.tool_calls[0].function.name == "send_message"
    args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
    assert args["recipient"] == "Amy"
    assert args["message"] == "It is a nice day"
    assert args["priority"] == "high"
def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
    """A tool call missing its closing tag is treated as plain content."""
    model_output = """<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>"""
    result = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    # Incomplete tool calls should not be extracted
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == model_output

View File

@@ -0,0 +1,308 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
import partial_json_parser
import pytest
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tool_parsers.jamba_tool_parser import JambaToolParser
# Jamba tool-parser tests use the tiny dev checkpoint's tokenizer.
MODEL = "ai21labs/Jamba-tiny-dev"
@pytest.fixture(scope="module")
def jamba_tokenizer():
    """Jamba tokenizer, loaded once per test module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
@pytest.fixture
def jamba_tool_parser(jamba_tokenizer):
    """A fresh JambaToolParser for each test."""
    tool_parser = JambaToolParser(jamba_tokenizer)
    return tool_parser
def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """Pairwise-compare tool calls; ids must be generated strings (>16 chars)."""
    assert len(actual_tool_calls) == len(expected_tool_calls)
    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        assert isinstance(actual.id, str)
        assert len(actual.id) > 16
        assert actual.type == "function"
        assert actual.function == expected.function
def stream_delta_message_generator(
    jamba_tool_parser: JambaToolParser,
    jamba_tokenizer: TokenizerLike,
    model_output: str,
) -> Generator[DeltaMessage, None, None]:
    """Feed model_output to the parser one token at a time, yielding the
    non-empty DeltaMessages the streaming parser produces."""
    all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False)
    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[: i + 1]
        # Incrementally detokenize so the text deltas match what the
        # engine's streaming path would hand to the parser.
        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
            detokenize_incrementally(
                tokenizer=jamba_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            )
        )
        current_text = previous_text + delta_text
        delta_message = jamba_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=None,  # type: ignore[arg-type]
        )
        if delta_message:
            yield delta_message
        # Carry the detokenizer state forward to the next iteration.
        previous_text = current_text
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset
def test_extract_tool_calls_no_tools(jamba_tool_parser):
    """Plain text without tool markers passes through unchanged."""
    model_output = "This is a test"
    result = jamba_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == model_output
@pytest.mark.parametrize(
    ids=[
        "single_tool",
        "single_tool_with_content",
        "parallel_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        # Only a leading space before <tool_calls> -> content is None.
        (
            """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            None,
        ),
        # Prose before the tool call is preserved as content.
        (
            """ Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            " Sure! let me call the tool for you.",
        ),
        # Two calls inside one <tool_calls> JSON array.
        (
            """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
                        ),
                    )
                ),
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls(
    jamba_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction: tool calls and leftover content must match."""
    extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
    assert extracted_tool_calls.content == expected_content
@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool",
        "single_tool_with_content",
        "parallel_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        # No tool markers: everything streams back as content.
        ("""This is a test""", [], """This is a test"""),
        # NOTE: in streaming mode the leading " " before <tool_calls>
        # is surfaced as content, unlike the non-streaming test above.
        (
            """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            " ",
        ),
        # Prose before the tool call streams out as content.
        (
            """ Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            " Sure! let me call the tool for you.",
        ),
        # Two calls inside one <tool_calls> JSON array.
        (
            """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
                        ),
                    )
                ),
            ],
            " ",
        ),
    ],
)
def test_extract_tool_calls_streaming(
    jamba_tool_parser,
    jamba_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Re-assemble streamed deltas and compare against the expected calls."""
    other_content: str = ""
    function_names: list[str] = []
    function_args_strs: list[str] = []
    tool_call_idx: int = -1
    tool_call_ids: list[str | None] = []
    for delta_message in stream_delta_message_generator(
        jamba_tool_parser, jamba_tokenizer, model_output
    ):
        # role should never be streamed from tool parser
        assert not delta_message.role
        if delta_message.content:
            other_content += delta_message.content
        streamed_tool_calls = delta_message.tool_calls
        if streamed_tool_calls and len(streamed_tool_calls) > 0:
            # make sure only one diff is present - correct even for parallel
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]
            # if a new tool is being called, set up empty arguments
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                function_args_strs.append("")
                tool_call_ids.append(None)
            # if a tool call ID is streamed, make sure one hasn't been already
            if tool_call.id and not tool_call_ids[tool_call.index]:
                tool_call_ids[tool_call.index] = tool_call.id
            # if parts of the function start being streamed
            if tool_call.function:
                # if the function name is defined, set it. it should be streamed
                # IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    function_names.append(tool_call.function.name)
                if tool_call.function.arguments:
                    # make sure they're a string and then add them to the list
                    assert isinstance(tool_call.function.arguments, str)
                    function_args_strs[tool_call.index] += tool_call.function.arguments
    assert other_content == expected_content
    # Rebuild ToolCall objects from the accumulated stream; ensure_json with
    # Allow.OBJ | Allow.STR tolerates JSON that may still be partial.
    actual_tool_calls = [
        ToolCall(
            id=tool_call_id,
            function=FunctionCall(
                name=function_name,
                arguments=partial_json_parser.ensure_json(
                    function_args_str, Allow.OBJ | Allow.STR
                ),
            ),
        )
        for tool_call_id, function_name, function_args_str in zip(
            tool_call_ids, function_names, function_args_strs
        )
    ]
    assert_tool_calls(actual_tool_calls, expected_tool_calls)

View File

@@ -0,0 +1,924 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.kimi_k2_tool_parser import KimiK2ToolParser
# Use a common model that is likely to be available; the fixture below loads
# its tokenizer with trust_remote_code=True.
MODEL = "moonshotai/Kimi-K2-Instruct"
@pytest.fixture(scope="module")
def kimi_k2_tokenizer():
    """Kimi-K2 tokenizer, loaded once per test module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
    return tokenizer
@pytest.fixture
def kimi_k2_tool_parser(kimi_k2_tokenizer):
    """A fresh KimiK2ToolParser for each test."""
    tool_parser = KimiK2ToolParser(kimi_k2_tokenizer)
    return tool_parser
def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """Compare tool calls and validate the "func_name:index"-style id."""
    assert len(actual_tool_calls) == len(expected_tool_calls)
    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        assert actual.type == "function"
        assert actual.function == expected.function
        # assert tool call id format: should contain function name and numeric index
        # Format can be either "functions.func_name:0" or "func_name:0"
        id_parts = actual.id.split(":")
        assert id_parts[-1].isdigit()
        assert id_parts[0].split(".")[-1] == expected.function.name
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
    """Output without tool-call markers is returned untouched as content."""
    model_output = "This is a test"
    result = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == model_output
@pytest.mark.parametrize(
    ids=[
        "tool_call_with_content_before",
        "multi_tool_call_with_content_before",
        "concatenated_tool_calls_bug_fix",
        "three_concatenated_tool_calls",
        "mixed_spacing_tool_calls",
        "angle_brackets_in_json",
        "newlines_in_json",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        # Single tool call preceded by reasoning content.
        (
            """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.get_weather:0",
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "city": "Beijing",
                            },
                        ),
                    ),
                    type="function",
                )
            ],
            "I'll help you check the weather. ",
        ),
        # Two tool calls, each with its own marker pair.
        (
            """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.get_weather:0",
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "city": "Beijing",
                            },
                        ),
                    ),
                    type="function",
                ),
                ToolCall(
                    id="functions.get_weather:1",
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "city": "Shanghai",
                            },
                        ),
                    ),
                    type="function",
                ),
            ],
            "I'll help you check the weather. ",
        ),
        # Regression case: back-to-back calls with no whitespace between markers.
        (
            """I'll get the weather and news for LA today. First, let me get the weather using Los Angeles coordinates, and then get the latest news. <|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"latitude": 34.0522, "longitude": -118.2437}<|tool_call_end|><|tool_call_begin|>functions.get_news:1<|tool_call_argument_begin|>{"content": "Los Angeles today"}<|tool_call_end|><|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.get_weather:0",
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {"latitude": 34.0522, "longitude": -118.2437}
                        ),
                    ),
                    type="function",
                ),
                ToolCall(
                    id="functions.get_news:1",
                    function=FunctionCall(
                        name="get_news",
                        arguments=json.dumps({"content": "Los Angeles today"}),
                    ),
                    type="function",
                ),
            ],
            "I'll get the weather and news for LA today. First, let me get the weather using Los Angeles coordinates, and then get the latest news. ",
        ),
        # Three concatenated calls in a single section.
        (
            """I'll help you with multiple tasks. <|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "New York"}<|tool_call_end|><|tool_call_begin|>functions.get_news:1<|tool_call_argument_begin|>{"topic": "technology"}<|tool_call_end|><|tool_call_begin|>functions.send_email:2<|tool_call_argument_begin|>{"to": "user@example.com", "subject": "Daily Update"}<|tool_call_end|><|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.get_weather:0",
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps({"city": "New York"}),
                    ),
                    type="function",
                ),
                ToolCall(
                    id="functions.get_news:1",
                    function=FunctionCall(
                        name="get_news",
                        arguments=json.dumps({"topic": "technology"}),
                    ),
                    type="function",
                ),
                ToolCall(
                    id="functions.send_email:2",
                    function=FunctionCall(
                        name="send_email",
                        arguments=json.dumps(
                            {"to": "user@example.com", "subject": "Daily Update"}
                        ),
                    ),
                    type="function",
                ),
            ],
            "I'll help you with multiple tasks. ",
        ),
        # Mixed spacing: first call padded with spaces, second tightly packed.
        (
            """Mixed spacing test. <|tool_calls_section_begin|> <|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {} <|tool_call_end|><|tool_call_begin|>functions.test2:1<|tool_call_argument_begin|>{}<|tool_call_end|> <|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.test:0",
                    function=FunctionCall(
                        name="test",
                        arguments=json.dumps({}),
                    ),
                    type="function",
                ),
                ToolCall(
                    id="functions.test2:1",
                    function=FunctionCall(
                        name="test2",
                        arguments=json.dumps({}),
                    ),
                    type="function",
                ),
            ],
            "Mixed spacing test. ",
        ),
        # Angle brackets inside the JSON payload must not confuse the
        # marker-based parsing.
        (
            """I need to process HTML content. <|tool_calls_section_begin|><|tool_call_begin|>functions.process_html:0<|tool_call_argument_begin|>{"html": "<div>content</div>", "text": "normal text"}<|tool_call_end|><|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.process_html:0",
                    function=FunctionCall(
                        name="process_html",
                        arguments=json.dumps(
                            {"html": "<div>content</div>", "text": "normal text"}
                        ),
                    ),
                    type="function",
                )
            ],
            "I need to process HTML content. ",
        ),
        # Pretty-printed JSON with embedded newlines; the raw payload is laid
        # out like json.dumps(..., indent=2) so the round-trip compares equal.
        (
            """I need to process formatted JSON. <|tool_calls_section_begin|><|tool_call_begin|>functions.process_data:0<|tool_call_argument_begin|>{
  "name": "test",
  "value": 123,
  "nested": {
    "key": "value"
  }
}<|tool_call_end|><|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.process_data:0",
                    function=FunctionCall(
                        name="process_data",
                        arguments=json.dumps(
                            {"name": "test", "value": 123, "nested": {"key": "value"}},
                            indent=2,
                        ),
                    ),
                    type="function",
                )
            ],
            "I need to process formatted JSON. ",
        ),
    ],
)
def test_extract_tool_calls(
    kimi_k2_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction: tool calls parsed, prefix kept as content."""
    extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
    assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_invalid_json(kimi_k2_tool_parser):
    """we'll return every funcall result"""
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing" <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
    extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    # Both calls are returned even though the first one's JSON is truncated:
    # the parser does NOT drop tool calls whose argument payload is malformed.
    assert len(extracted_tool_calls.tool_calls) == 2
    assert extracted_tool_calls.tool_calls[0].function.name == "invalid_get_weather"
    assert extracted_tool_calls.tool_calls[1].function.name == "valid_get_weather"
def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser):
    """Tool calls with a malformed id are dropped; well-formed ones are kept."""
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
    extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    # Only the call with a well-formed "name:index" id survives; the first one
    # uses a dot ("invalid_get_weather.0") instead of a colon and is discarded.
    assert len(extracted_tool_calls.tool_calls) == 1
    assert extracted_tool_calls.tool_calls[0].function.name == "valid_get_weather"
def test_streaming_basic_functionality(kimi_k2_tool_parser):
    """Test basic streaming functionality."""
    # Reset streaming state by hand (mirrors what reset_streaming_state does).
    kimi_k2_tool_parser.current_tool_name_sent = False
    kimi_k2_tool_parser.prev_tool_call_arr = []
    kimi_k2_tool_parser.current_tool_id = -1
    kimi_k2_tool_parser.streamed_args_for_tool = []
    # Test with a simple tool call
    current_text = """ check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>"""
    # First call should handle the initial setup
    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you",
        current_text=current_text,
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )
    # The result might be None or contain tool call information,
    # depending on the parser's internal state management.
    # NOTE(review): `len(...) >= 0` can never fail — this is a smoke test
    # that only verifies the streaming call completes without raising.
    if result is not None and hasattr(result, "tool_calls") and result.tool_calls:
        assert len(result.tool_calls) >= 0
def test_streaming_no_tool_calls(kimi_k2_tool_parser):
    """Plain text with no tool-call markers streams straight through as content."""
    delta = " without any tool calls."
    prefix = "This is just regular text"
    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text=prefix,
        current_text=prefix + delta,
        delta_text=delta,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )
    # The delta text must be echoed back unchanged.
    assert result is not None
    assert hasattr(result, "content")
    assert result.content == delta
def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
    """
    Test that text between <|tool_calls_section_begin|> and <|tool_call_begin|>
    is suppressed and does not leak into reasoning_delta.
    This is the main vulnerability being fixed.
    """
    kimi_k2_tool_parser.reset_streaming_state()
    # Get token IDs for the markers from the tokenizer vocab.
    section_begin_token_id = kimi_k2_tool_parser.vocab.get(
        "<|tool_calls_section_begin|>"
    )
    tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
    # Simulate streaming sequence:
    # Delta 1: "I'll help you with that. " — ordinary reasoning text.
    result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="I'll help you with that. ",
        delta_text="I'll help you with that. ",
        previous_token_ids=[],
        current_token_ids=[1, 2, 3],  # Regular tokens
        delta_token_ids=[1, 2, 3],
        request=None,
    )
    assert result1 is not None
    assert result1.content == "I'll help you with that. "
    # Delta 2: "<|tool_calls_section_begin|>" — enters the tool section.
    prev_ids = [1, 2, 3]
    curr_ids = prev_ids + [section_begin_token_id]
    result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you with that. ",
        current_text="I'll help you with that. <|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=prev_ids,
        current_token_ids=curr_ids,
        delta_token_ids=[section_begin_token_id],
        request=None,
    )
    # Section marker should be stripped and suppressed
    assert result2 is None or (result2.content is None or result2.content == "")
    # Delta 3: " spurious text or tokens " (THE LEAK SCENARIO)
    prev_ids = curr_ids
    curr_ids = curr_ids + [4, 5]
    result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you with that. <|tool_calls_section_begin|>",
        current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
        delta_text=" spurious text ",
        previous_token_ids=prev_ids,
        current_token_ids=curr_ids,
        delta_token_ids=[4, 5],
        request=None,
    )
    # CRITICAL: This text should be suppressed, NOT returned as reasoning_delta
    assert result3 is None or (result3.content is None or result3.content == "")
    # Delta 4: "<|tool_call_begin|>..." — first tool call starts.
    prev_ids = curr_ids
    curr_ids = curr_ids + [tool_call_begin_token_id]
    _result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
        current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text <|tool_call_begin|>",
        delta_text="<|tool_call_begin|>",
        previous_token_ids=prev_ids,
        current_token_ids=curr_ids,
        delta_token_ids=[tool_call_begin_token_id],
        request=None,
    )
    # Now we're in tool call mode; the result depends on internal state.
    # The key assertion already happened: Delta 3's spurious text was not leaked.
def test_split_markers_across_deltas(kimi_k2_tool_parser):
    """
    Test that markers split across delta chunks are correctly detected
    via the rolling buffer mechanism.
    """
    kimi_k2_tool_parser.reset_streaming_state()
    section_begin_token_id = kimi_k2_tool_parser.vocab.get(
        "<|tool_calls_section_begin|>"
    )
    # Delta 1: "...reasoning<|tool_calls_sec" — marker text cut mid-way.
    _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Some reasoning",
        current_text="Some reasoning<|tool_calls_sec",
        delta_text="<|tool_calls_sec",
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, 3],  # Partial token
        delta_token_ids=[3],
        request=None,
    )
    # Partial token not recognized yet, might be buffered.
    # Should return as content or None (depends on implementation).
    # Delta 2: "tion_begin|> " (completes the marker)
    _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Some reasoning<|tool_calls_sec",
        current_text="Some reasoning<|tool_calls_section_begin|> ",
        delta_text="tion_begin|> ",
        previous_token_ids=[1, 2, 3],
        current_token_ids=[1, 2, section_begin_token_id, 4],
        delta_token_ids=[section_begin_token_id, 4],
        request=None,
    )
    # Now the complete marker should be detected via the rolling buffer,
    # and the parser should enter tool section mode.
    assert kimi_k2_tool_parser.in_tool_section is True
def test_marker_variants(kimi_k2_tool_parser):
    """Test that both singular and plural marker variants are recognized."""
    kimi_k2_tool_parser.reset_streaming_state()
    # Singular variant: <|tool_call_section_begin|> (note: singular "call").
    singular_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_section_begin|>")
    if singular_token_id is None:
        # Previously the test silently passed when the tokenizer lacked this
        # token; skip explicitly so a vacuous run is visible in test reports.
        pytest.skip("tokenizer does not define <|tool_call_section_begin|>")
    _result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning <|tool_call_section_begin|>",
        delta_text="<|tool_call_section_begin|>",
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, singular_token_id],
        delta_token_ids=[singular_token_id],
        request=None,
    )
    # Should enter tool section mode with the singular variant too.
    assert kimi_k2_tool_parser.in_tool_section is True
def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
    """
    Test that after exiting a tool section with <|tool_calls_section_end|>,
    subsequent text is correctly returned as reasoning content.
    """
    kimi_k2_tool_parser.reset_streaming_state()
    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
    # Enter tool section
    _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="<|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[],
        current_token_ids=[section_begin_id],
        delta_token_ids=[section_begin_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is True
    # Exit tool section
    _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|>",
        current_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[section_begin_id],
        current_token_ids=[section_begin_id, section_end_id],
        delta_token_ids=[section_end_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is False
    # Subsequent reasoning text should be returned normally, NOT suppressed.
    result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
        current_text="<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning",
        delta_text=" More reasoning",
        previous_token_ids=[section_begin_id, section_end_id],
        current_token_ids=[section_begin_id, section_end_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )
    assert result3 is not None
    assert result3.content == " More reasoning"
def test_empty_tool_section(kimi_k2_tool_parser):
    """An empty tool section (begin immediately followed by end) exits cleanly."""
    kimi_k2_tool_parser.reset_streaming_state()
    begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
    opened = "Reasoning <|tool_calls_section_begin|>"
    # The section-begin marker arrives first.
    kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text=opened,
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[1],
        current_token_ids=[1, begin_id],
        delta_token_ids=[begin_id],
        request=None,
    )
    # Followed immediately by the section-end marker with nothing in between.
    kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text=opened,
        current_text=opened + "<|tool_calls_section_end|>",
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[1, begin_id],
        current_token_ids=[1, begin_id, end_id],
        delta_token_ids=[end_id],
        request=None,
    )
    # The parser must leave the section without raising.
    assert kimi_k2_tool_parser.in_tool_section is False
def test_malformed_tool_section_recovery(kimi_k2_tool_parser):
    """
    Test that the parser recovers from a malformed tool section
    that never closes properly.
    """
    kimi_k2_tool_parser.reset_streaming_state()
    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    # Enter tool section
    _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="<|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[],
        current_token_ids=[section_begin_id],
        delta_token_ids=[section_begin_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is True
    # Simulate a lot of text without proper tool calls or a section end.
    # This should trigger the parser's error-recovery mechanism.
    large_text = "x" * 10000  # Exceeds max_section_chars
    result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|>",
        current_text="<|tool_calls_section_begin|>" + large_text,
        delta_text=large_text,
        previous_token_ids=[section_begin_id],
        current_token_ids=[section_begin_id] + list(range(100, 100 + len(large_text))),
        delta_token_ids=list(range(100, 100 + len(large_text))),
        request=None,
    )
    # Parser should have force-exited the tool section...
    assert kimi_k2_tool_parser.in_tool_section is False
    # ...and returned the accumulated content as reasoning instead of losing it.
    assert result2 is not None
    assert result2.content == large_text
def test_state_reset(kimi_k2_tool_parser):
    """reset_streaming_state() must restore every piece of streaming state."""
    parser = kimi_k2_tool_parser
    # Drive the parser into a messy mid-stream state.
    parser.in_tool_section = True
    parser.token_buffer = "some buffer"
    parser.current_tool_id = 5
    parser.prev_tool_call_arr = [{"id": "test"}]
    parser.section_char_count = 1000
    # Reset and verify everything is back to pristine defaults.
    parser.reset_streaming_state()
    assert parser.in_tool_section is False
    assert parser.token_buffer == ""
    assert parser.current_tool_id == -1
    assert parser.prev_tool_call_arr == []
    assert parser.section_char_count == 0
    assert parser.current_tool_name_sent is False
    assert parser.streamed_args_for_tool == []
def test_section_begin_noise_tool_begin_same_chunk(kimi_k2_tool_parser):
    """
    Test that begin→noise→tool_begin within the SAME chunk suppresses
    the noise text correctly (not just across chunks).
    """
    kimi_k2_tool_parser.reset_streaming_state()
    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    tool_call_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
    # Single delta containing: section_begin + spurious text + tool_call_begin
    combined_text = "<|tool_calls_section_begin|> noise text <|tool_call_begin|>"
    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning " + combined_text,
        delta_text=combined_text,
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, section_begin_id, 3, 4, tool_call_begin_id],
        delta_token_ids=[section_begin_id, 3, 4, tool_call_begin_id],
        request=None,
    )
    # The noise text should NOT leak into content.
    # Result should either be None/empty or start tool call parsing.
    if result is not None and result.content is not None:
        # If any content is returned, it must not contain the noise.
        assert "noise text" not in result.content
        assert result.content == "" or result.content.strip() == ""
def test_stream_ends_without_section_end_marker(kimi_k2_tool_parser):
    """
    Test that if the stream ends (EOF) without a proper section end marker,
    the parser doesn't leak text, doesn't crash, and resets state cleanly.
    """
    kimi_k2_tool_parser.reset_streaming_state()
    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    # Enter tool section
    _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="<|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[],
        current_token_ids=[section_begin_id],
        delta_token_ids=[section_begin_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is True
    # Some content in tool section
    result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|>",
        current_text="<|tool_calls_section_begin|> partial content",
        delta_text=" partial content",
        previous_token_ids=[section_begin_id],
        current_token_ids=[section_begin_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )
    # Content should be suppressed. Guard against result2 being None first:
    # other tests in this file accept None from the streaming call, and the
    # previous unguarded `result2.content` access could raise AttributeError.
    assert result2 is None or result2.content is None or result2.content == ""
    # Stream ends (EOF) - no more deltas, no section_end marker.
    # Simulate this by manually checking state and resetting
    # (in real usage, the request handler would call reset_streaming_state).
    assert kimi_k2_tool_parser.in_tool_section is True  # Still in section
    # Reset state (as would happen between requests)
    kimi_k2_tool_parser.reset_streaming_state()
    # Verify clean slate
    assert kimi_k2_tool_parser.in_tool_section is False
    assert kimi_k2_tool_parser.token_buffer == ""
    # Next request should work normally
    result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="New reasoning",
        delta_text="New reasoning",
        previous_token_ids=[],
        current_token_ids=[20, 21],
        delta_token_ids=[20, 21],
        request=None,
    )
    assert result3 is not None
    assert result3.content == "New reasoning"
def test_same_chunk_begin_and_end_markers(kimi_k2_tool_parser):
    """
    CRITICAL TEST: Verify that when both section_begin and section_end
    markers appear in the SAME chunk, the parser correctly:
    1. Enters the tool section
    2. Immediately exits the tool section
    3. Does NOT get stuck in in_tool_section=True state
    This tests the bug fix where elif was changed to if to handle
    both state transitions in a single delta.
    """
    kimi_k2_tool_parser.reset_streaming_state()
    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
    # Single chunk with both markers (e.g. an empty tool section).
    combined_delta = "<|tool_calls_section_begin|><|tool_calls_section_end|>"
    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Some reasoning ",
        current_text="Some reasoning " + combined_delta,
        delta_text=combined_delta,
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, section_begin_id, section_end_id],
        delta_token_ids=[section_begin_id, section_end_id],
        request=None,
    )
    # CRITICAL: Parser should NOT be stuck in tool section
    assert kimi_k2_tool_parser.in_tool_section is False, (
        "Parser stuck in tool section after processing both begin/end in same chunk. "
        "This indicates the elif bug was not fixed."
    )
    # Result should be empty or contain only stripped content
    assert result is not None
    assert result.content == "" or result.content is None
    # Verify subsequent content streams correctly (not suppressed)
    result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Some reasoning " + combined_delta,
        current_text="Some reasoning " + combined_delta + " More reasoning",
        delta_text=" More reasoning",
        previous_token_ids=[1, 2, section_begin_id, section_end_id],
        current_token_ids=[1, 2, section_begin_id, section_end_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )
    # This content should NOT be suppressed (we're out of the tool section).
    assert result2 is not None
    assert result2.content == " More reasoning"
def test_same_chunk_begin_content_end_markers(kimi_k2_tool_parser):
    """
    Test the same-chunk scenario with actual content between markers.
    Example: <|tool_calls_section_begin|> text <|tool_calls_section_end|>
    all arriving in one delta. The key is that the state machine correctly
    transitions in and out within the same chunk.
    """
    kimi_k2_tool_parser.reset_streaming_state()
    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
    # Chunk with begin, some whitespace/noise, and end all together.
    # This simulates a tool section that opens and closes in the same chunk.
    combined_delta = "<|tool_calls_section_begin|> <|tool_calls_section_end|>"
    _result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning " + combined_delta,
        delta_text=combined_delta,
        previous_token_ids=[1],
        current_token_ids=[1, section_begin_id, 100, section_end_id],
        delta_token_ids=[section_begin_id, 100, section_end_id],
        request=None,
    )
    # Parser should exit cleanly (not stuck in tool section)
    assert kimi_k2_tool_parser.in_tool_section is False
    # Verify the fix: the next content should stream normally, not be suppressed.
    result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning " + combined_delta,
        current_text="Reasoning " + combined_delta + " Done",
        delta_text=" Done",
        previous_token_ids=[1, section_begin_id, 100, section_end_id],
        current_token_ids=[1, section_begin_id, 100, section_end_id, 200],
        delta_token_ids=[200],
        request=None,
    )
    # Content after the section should be returned (not suppressed).
    assert result2 is not None
    assert result2.content == " Done"
def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
    """
    CRITICAL TEST (P1): Verify that when both <|tool_call_end|> and
    <|tool_calls_section_end|> appear in the SAME chunk, the parser:
    1. Processes the tool_call_end first (emits final arguments)
    2. THEN exits the section
    3. Does NOT drop the final tool call update
    4. Does NOT leak special tokens into reasoning
    This tests the deferred section exit fix.
    """
    kimi_k2_tool_parser.reset_streaming_state()
    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
    tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
    tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
    # Simulate a streaming sequence for a SHORT tool call (all in one chunk):
    # 1. Reasoning text
    result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="Let me help. ",
        delta_text="Let me help. ",
        previous_token_ids=[],
        current_token_ids=[1, 2],
        delta_token_ids=[1, 2],
        request=None,
    )
    assert result1 is not None
    assert result1.content == "Let me help. "
    # 2. Section begin
    _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Let me help. ",
        current_text="Let me help. <|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, section_begin_id],
        delta_token_ids=[section_begin_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is True
    # 3. Tool call begin + full content + tool_end + section_end ALL IN ONE CHUNK.
    # This is the critical scenario for short tool calls.
    combined = (
        '<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
        "<|tool_call_end|><|tool_calls_section_end|>"
    )
    # Build up the previous text gradually to simulate realistic streaming
    prev_text = "Let me help. <|tool_calls_section_begin|>"
    curr_text = prev_text + combined
    result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text=prev_text,
        current_text=curr_text,
        delta_text=combined,
        previous_token_ids=[1, 2, section_begin_id],
        current_token_ids=[
            1,
            2,
            section_begin_id,
            tool_begin_id,
            10,
            11,
            12,
            tool_end_id,
            section_end_id,
        ],
        delta_token_ids=[tool_begin_id, 10, 11, 12, tool_end_id, section_end_id],
        request=None,
    )
    # CRITICAL: Parser should have exited the section AFTER processing the tool.
    assert kimi_k2_tool_parser.in_tool_section is False
    # Tool call should have been emitted (not dropped).
    # The result might be the tool name or None depending on state, but
    # importantly, it shouldn't be returning the literal tokens as content.
    if result3 is not None and result3.content is not None:
        # Verify no special tokens leaked into content
        assert "<|tool_call_end|>" not in result3.content
        assert "<|tool_calls_section_end|>" not in result3.content
    # 4. Verify subsequent content streams normally
    result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text=curr_text,
        current_text=curr_text + " Done",
        delta_text=" Done",
        previous_token_ids=[
            1,
            2,
            section_begin_id,
            tool_begin_id,
            10,
            11,
            12,
            tool_end_id,
            section_end_id,
        ],
        current_token_ids=[
            1,
            2,
            section_begin_id,
            tool_begin_id,
            10,
            11,
            12,
            tool_end_id,
            section_end_id,
            20,
        ],
        delta_token_ids=[20],
        request=None,
    )
    # Content after the tool section should stream normally.
    assert result4 is not None
    assert result4.content == " Done"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,860 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
import partial_json_parser
import pytest
from mistral_common.protocol.instruct.messages import AssistantMessage
from mistral_common.protocol.instruct.request import InstructRequest
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
@pytest.fixture(scope="module")
def mistral_pre_v11_tokenizer():
    """Pre-v11 (HF-style) Mistral tokenizer, loaded once per test module."""
    model_name = "mistralai/Mistral-7B-Instruct-v0.3"
    return get_tokenizer(tokenizer_name=model_name)
@pytest.fixture(scope="module")
def mistral_tokenizer():
    """Native (v11+) Mistral tokenizer, loaded once per test module."""
    model_name = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
    return get_tokenizer(tokenizer_name=model_name, tokenizer_mode="mistral")
@pytest.fixture
def mistral_pre_v11_tool_parser(mistral_pre_v11_tokenizer):
    """Tool parser bound to the pre-v11 tokenizer; rebuilt for every test."""
    parser = MistralToolParser(mistral_pre_v11_tokenizer)
    return parser
@pytest.fixture
def mistral_tool_parser(mistral_tokenizer):
    """Tool parser bound to the native Mistral tokenizer; rebuilt per test."""
    parser = MistralToolParser(mistral_tokenizer)
    return parser
def assert_tool_calls(
    actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
    expected_tool_calls: list[ToolCall],
):
    """Check parsed tool calls (full or streamed deltas) against expectations.

    Mistral tool-call ids are random 9-character strings, so only the id's
    type and length are checked, not its value.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)
    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) == 9
        if isinstance(actual_tool_call, ToolCall):
            assert actual_tool_call.type == "function"
        elif isinstance(actual_tool_call, DeltaToolCall):
            # Streaming deltas may carry partial data; require a complete call.
            assert actual_tool_call.function is not None
            assert actual_tool_call.function.name is not None
            assert actual_tool_call.function.arguments is not None
        assert actual_tool_call.function is not None
        # Fixed messages: the originals read ":${...}" — a JavaScript-style
        # template leftover that printed a literal "$" in failure output.
        assert actual_tool_call.function.name == expected_tool_call.function.name, (
            f"got wrong function name: {actual_tool_call.function.name}"
        )
        assert (
            actual_tool_call.function.arguments == expected_tool_call.function.arguments
        ), f"got wrong function argument: {actual_tool_call.function.arguments}"
def fix_tool_call_tokenization(
    tokens: list[int],
    mistral_tool_parser: "MistralToolParser",
    mistral_tokenizer: "TokenizerLike",
) -> list[int]:
    """
    Replaces the textual token sequence for [TOOL_CALLS]
    with its single special token ID.

    Args:
        tokens: the token ids to rewrite.
        mistral_tool_parser: supplies ``bot_token`` (the textual marker) and
            ``bot_token_id`` (its special token id).
        mistral_tokenizer: used to encode the textual marker into the token
            sequence that should be collapsed.

    Returns:
        A new token list with every occurrence of the textual marker sequence
        replaced by the single special token id; non-matching tokens are
        copied through unchanged.
    """
    textual_tool_call_token_ids = mistral_tokenizer.encode(
        text=mistral_tool_parser.bot_token,
        add_special_tokens=False,
    )
    # textual_tool_call_token_ids must not contain special tokens like bos, eos etc
    special_tool_call_token_ids = [mistral_tool_parser.bot_token_id]
    # Guard against a degenerate empty target sequence: a zero-length slice
    # always matches and the scan index would never advance (infinite loop).
    if not textual_tool_call_token_ids:
        return tokens
    # If the input is too short to contain the sequence, no replacement is possible
    if not tokens or len(tokens) < len(textual_tool_call_token_ids):
        return tokens
    result_tokens = []
    i = 0
    target_len = len(textual_tool_call_token_ids)
    while i < len(tokens):
        # Check if the slice from the current position matches the target sequence
        if tokens[i : i + target_len] == textual_tool_call_token_ids:
            # If it matches, add the replacement and jump the index forward
            result_tokens.extend(special_tool_call_token_ids)
            i += target_len
        else:
            # Otherwise, just add the current token and move to the next one
            result_tokens.append(tokens[i])
            i += 1
    return result_tokens
def stream_delta_message_generator(
    mistral_tool_parser: MistralToolParser,
    mistral_tokenizer: TokenizerLike,
    model_output: str | None,
    tools: list[tuple[str, str]] | None,
) -> Generator[DeltaMessage, None, None]:
    """Replay the model output through the parser one token at a time.

    For v11+ Mistral tokenizers (which cannot tokenize free text) the tool
    calls are rendered via a structured assistant message; older tokenizers
    encode the raw model output directly. Each single-token delta is then
    detokenized incrementally and fed to ``extract_tool_calls_streaming``,
    yielding every non-empty DeltaMessage produced.
    """
    use_instruct_encoding = (
        isinstance(mistral_tokenizer, MistralTokenizer)
        and mistral_tokenizer.version >= 11
    )
    if use_instruct_encoding:
        # Newer tokenizer versions cannot tokenize free text, so build a list
        # of messages containing the tool calls and tokenize that instead.
        assert tools is not None
        tool_calls = [
            ToolCall(function=FunctionCall(name=name, arguments=arg))
            for name, arg in tools
        ]
        request = InstructRequest(
            messages=[AssistantMessage(tool_calls=tool_calls)],
        )
        all_token_ids = mistral_tokenizer.instruct.encode_instruct(request).tokens
    else:
        # Older tokenizers can encode the model's free-text output directly.
        assert model_output is not None
        all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)

    all_token_ids = fix_tool_call_tokenization(
        all_token_ids, mistral_tool_parser, mistral_tokenizer
    )

    text_so_far = ""
    tokens_so_far = None
    prefix_offset = 0
    read_offset = 0
    for idx, token_id in enumerate(all_token_ids):
        seen_ids = all_token_ids[: idx + 1]

        new_tokens, delta_text, next_prefix, next_read = detokenize_incrementally(
            tokenizer=mistral_tokenizer,
            all_input_ids=seen_ids,
            prev_tokens=tokens_so_far,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=isinstance(mistral_tokenizer, MistralTokenizer),
            spaces_between_special_tokens=True,
        )

        updated_text = text_so_far + delta_text
        delta_message = mistral_tool_parser.extract_tool_calls_streaming(
            text_so_far,
            updated_text,
            delta_text,
            all_token_ids[:idx],
            seen_ids,
            [token_id],
            request=None,  # type: ignore[arg-type]
        )
        if delta_message:
            yield delta_message

        text_so_far = updated_text
        tokens_so_far = tokens_so_far + new_tokens if tokens_so_far else new_tokens
        prefix_offset = next_prefix
        read_offset = next_read
def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser):
    """Plain text without a [TOOL_CALLS] marker passes through as content."""
    plain_text = "This is a test"
    result = mistral_pre_v11_tool_parser.extract_tool_calls(
        plain_text, request=None
    )  # type: ignore[arg-type]
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == plain_text
@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_weather",
        "argument_before_name",
        "argument_before_name_and_name_in_argument",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
        ),
        # "arguments" may appear before "name" in the JSON object.
        (
            """[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
        ),
        # A "name" key inside the arguments must not be mistaken for the
        # function name.
        (
            """[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps(
                            {
                                "name": "John Doe",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction with the pre-v11 output format.

    Pre-v11 models emit a JSON list of {"name", "arguments"} objects after the
    [TOOL_CALLS] marker; key order within each object is not fixed.
    """
    extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
    assert extracted_tool_calls.content == expected_content
@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_weather",
        "multiple_tool_calls",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
        ),
        # Parallel calls: each one is introduced by its own [TOOL_CALLS] marker.
        (
            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="multiply", arguments=json.dumps({"a": 3, "b": 6})
                    )
                ),
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls(
    mistral_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction with the v11+ output format.

    v11+ models emit ``[TOOL_CALLS]<name><json-args>`` (no surrounding JSON
    list), one marker per call.
    """
    extracted_tool_calls = mistral_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
    assert extracted_tool_calls.content == expected_content
def _test_extract_tool_calls_streaming(
    tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content
):
    """Replay the output token-by-token and verify the reassembled tool calls.

    Accumulates streamed content, per-index tool-call IDs, function names, and
    argument fragments, then rebuilds the tool calls and compares them against
    the expected ones.
    """
    content_accum: str = ""
    streamed_names: list[str] = []
    streamed_args: list[str] = []
    call_ids: list[str | None] = []
    current_index: int = -1

    for delta in stream_delta_message_generator(
        tool_parser, tokenizer, model_output, tools
    ):
        # role should never be streamed from tool parser
        assert not delta.role

        if delta.content:
            content_accum += delta.content

        if delta.tool_calls:
            # make sure only one diff is present - correct even for parallel
            assert len(delta.tool_calls) == 1
            tool_call = delta.tool_calls[0]

            assert len(tool_parser.prev_tool_call_arr) > 0

            # a new tool call begins: set up empty accumulators for it
            if tool_call.index != current_index:
                current_index = tool_call.index
                streamed_args.append("")
                call_ids.append(None)

            # a tool call ID may be streamed at most once per call
            if tool_call.id and not call_ids[tool_call.index]:
                call_ids[tool_call.index] = tool_call.id

            if tool_call.function:
                # the function name is streamed IN ENTIRETY, exactly one time
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    streamed_names.append(tool_call.function.name)

                if tool_call.function.arguments:
                    # arguments arrive as string fragments; concatenate them
                    assert isinstance(tool_call.function.arguments, str)
                    streamed_args[tool_call.index] += tool_call.function.arguments

    assert content_accum == expected_content

    actual_tool_calls = [
        ToolCall(
            id=call_id,
            function=FunctionCall(
                name=call_name,
                arguments=partial_json_parser.ensure_json(
                    call_args, Allow.OBJ | Allow.STR
                ),
            ),
        )
        for call_id, call_name, call_args in zip(
            call_ids, streamed_names, streamed_args
        )
    ]
    assert_tool_calls(actual_tool_calls, expected_tool_calls)
@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool_add",
        "single_tool_add_strings",
        "single_tool_weather",
        "argument_before_name",
        "argument_before_name_and_name_in_argument",
        "multiple_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ("""This is a test""", [], """This is a test"""),
        # Whitespace inside the JSON list must be tolerated.
        (
            """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3, "b": 4})
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": "3", "b": "4"})
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        # "arguments" may appear before "name" in the JSON object.
        (
            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        # A "name" key inside the arguments must not be mistaken for the
        # function name.
        (
            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps(
                            {
                                "name": "John Doe",
                            }
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
        ),
    ],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Streaming extraction, pre-v11 format, one token per delta."""
    _test_extract_tool_calls_streaming(
        mistral_pre_v11_tool_parser,
        mistral_pre_v11_tokenizer,
        model_output,
        None,
        expected_tool_calls,
        expected_content,
    )
@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_add_strings",
        "multiple_tools",
    ],
    # v11+ tokenizers cannot encode free text, so the tool calls are supplied
    # as (name, json-arguments) pairs and rendered by the tokenizer itself.
    argnames=["tools", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            [("add", '{"a": 3, "b": 4}')],
            # [TOOL_CALLS]add{"a": 3, "b": 4}
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3, "b": 4})
                    )
                )
            ],
            "",
        ),
        (
            [("add_two_strings", '{"a": "3", "b": "4"}')],
            # [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"}
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_two_strings",
                        arguments=json.dumps({"a": "3", "b": "4"}),
                    )
                )
            ],
            "",
        ),
        (
            [
                ("add", '{"a": 3.5, "b": 4}'),
                (
                    "get_current_weather",
                    '{"city": "San Francisco", "state": "CA", "unit": "celsius"}',  # noqa: E501
                ),
            ],
            # [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"} # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
        ),
    ],
)
def test_extract_tool_calls_streaming(
    mistral_tool_parser,
    mistral_tokenizer,
    tools,
    expected_tool_calls,
    expected_content,
):
    """Streaming extraction, v11+ format, one token per delta."""
    _test_extract_tool_calls_streaming(
        mistral_tool_parser,
        mistral_tokenizer,
        None,
        tools,
        expected_tool_calls,
        expected_content,
    )
@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_weather",
        "multiple_tool_calls",
        "content_before_tool",
        "complex",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="multiply", arguments=json.dumps({"a": 3, "b": 6})
                    )
                ),
            ],
            "",
        ),
        (
            # Additional content should not be after the tool calls
            """bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            "bla",
        ),
        (
            # Complex
            """[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="bash",
                        arguments=json.dumps(
                            {"command": "print(\"hello world!\")\nre.compile(r'{}')"}
                        ),
                    )
                )
            ],
            "",
        ),
    ],
)
def test_extract_tool_calls_streaming_one_chunk(
    mistral_tool_parser,
    mistral_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Streaming extraction where the whole output arrives as a single delta."""
    if isinstance(mistral_tokenizer, MistralTokenizer):
        all_token_ids = mistral_tokenizer.encode(model_output)
    else:
        all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)
    all_token_ids = fix_tool_call_tokenization(
        all_token_ids, mistral_tool_parser, mistral_tokenizer
    )
    delta_message = mistral_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=all_token_ids,
        delta_token_ids=all_token_ids,
        request=None,
    )  # type: ignore[arg-type]

    assert isinstance(delta_message, DeltaMessage)
    assert len(delta_message.tool_calls) == len(expected_tool_calls)
    assert_tool_calls(delta_message.tool_calls, expected_tool_calls)
    # Parsers may report empty content as None; treat both the same.
    if delta_message.content is None:
        assert expected_content == ""
    else:
        assert delta_message.content == expected_content
@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool_add",
        "single_tool_add_strings",
        "single_tool_weather",
        "argument_before_name",
        "argument_before_name_and_name_in_argument",
        "multiple_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ("""This is a test""", [], """This is a test"""),
        # Whitespace inside the JSON list must be tolerated.
        (
            """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3, "b": 4})
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": "3", "b": "4"})
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        # "arguments" may appear before "name" in the JSON object.
        (
            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        # A "name" key inside the arguments must not be mistaken for the
        # function name.
        (
            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps(
                            {
                                "name": "John Doe",
                            }
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
        ),
    ],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Streaming extraction (pre-v11) where the whole output is one delta."""
    if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer):
        all_token_ids = mistral_pre_v11_tokenizer.encode(model_output)
    else:
        all_token_ids = mistral_pre_v11_tokenizer.encode(
            model_output, add_special_tokens=False
        )
    all_token_ids = fix_tool_call_tokenization(
        all_token_ids, mistral_pre_v11_tool_parser, mistral_pre_v11_tokenizer
    )
    delta_message = mistral_pre_v11_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=all_token_ids,
        delta_token_ids=all_token_ids,
        request=None,
    )  # type: ignore[arg-type]

    assert isinstance(delta_message, DeltaMessage)
    assert len(delta_message.tool_calls) == len(expected_tool_calls)
    assert_tool_calls(delta_message.tool_calls, expected_tool_calls)
    # Parsers may report empty content as None; treat both the same.
    if delta_message.content is None:
        assert expected_content == ""
    else:
        assert delta_message.content == expected_content

View File

@@ -0,0 +1,263 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from openai_harmony import (
Conversation,
DeveloperContent,
HarmonyEncodingName,
Message,
Role,
SystemContent,
load_harmony_encoding,
)
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.openai_tool_parser import OpenAIToolParser
MODEL = "gpt2"
@pytest.fixture(scope="module")
def openai_tokenizer():
    """Module-scoped tokenizer; the parser never uses it, but its constructor requires one."""
    tokenizer = get_tokenizer(MODEL)
    return tokenizer
@pytest.fixture
def openai_tool_parser(openai_tokenizer):
    """Fresh OpenAIToolParser instance for each test."""
    parser = OpenAIToolParser(openai_tokenizer)
    return parser
@pytest.fixture(scope="module")
def harmony_encoding():
    """Harmony encoding used to render conversations into token IDs."""
    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    return encoding
def assert_tool_calls(
    actual_tool_calls: list[ToolCall],
    expected_tool_calls: list[ToolCall],
):
    """Compare tool calls pairwise; IDs are generated, so only their shape is checked."""
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        assert isinstance(actual.id, str)
        assert len(actual.id) > 16  # Default from protocol.py
        assert actual.type == "function"
        assert actual.function == expected.function
def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
    """A conversation whose assistant turn has no tool calls yields plain content."""
    messages = [
        Message.from_role_and_content(
            Role.SYSTEM,
            SystemContent.new(),
        ),
        Message.from_role_and_content(
            Role.DEVELOPER,
            DeveloperContent.new().with_instructions("Talk like a pirate!"),
        ),
        Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
        Message.from_role_and_content(Role.ASSISTANT, "This is a test").with_channel(
            "final"
        ),
    ]
    token_ids = harmony_encoding.render_conversation_for_completion(
        Conversation.from_messages(messages), Role.ASSISTANT
    )

    extracted_info = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )
    assert not extracted_info.tools_called
    assert extracted_info.tool_calls == []
    assert extracted_info.content == "This is a test"
@pytest.mark.parametrize(
    "tool_args",
    # Same JSON payload with and without embedded newlines.
    [
        '{"location": "Tokyo"}',
        '{\n"location": "Tokyo"\n}',
    ],
)
def test_extract_tool_calls_single_tool(
    openai_tool_parser, harmony_encoding, tool_args
):
    """A single commentary message addressed to a function becomes one tool call."""
    convo = Conversation.from_messages(
        [
            Message.from_role_and_content(Role.USER, "What is the weather in Tokyo?"),
            Message.from_role_and_content(
                Role.ASSISTANT,
                'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.',  # noqa: E501
            ).with_channel("analysis"),
            Message.from_role_and_content(Role.ASSISTANT, tool_args)
            .with_channel("commentary")
            .with_recipient("functions.get_current_weather")
            .with_content_type("json"),
        ]
    )
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo, Role.ASSISTANT
    )

    extracted_info = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )
    assert extracted_info.tools_called
    expected_tool_calls = [
        ToolCall(
            function=FunctionCall(
                name="get_current_weather",
                arguments=json.dumps({"location": "Tokyo"}),
            )
        )
    ]
    assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
    assert extracted_info.content is None
def test_extract_tool_calls_multiple_tools(
    openai_tool_parser,
    harmony_encoding,
):
    """Multiple tool calls in one turn, covering payload/content-type variants.

    Exercises: JSON payloads with and without an explicit content type, a
    non-JSON payload, empty-object arguments, and entirely absent arguments.
    """
    convo = Conversation.from_messages(
        [
            Message.from_role_and_content(
                Role.USER, "What is the weather in Tokyo based on where I'm at?"
            ),
            Message.from_role_and_content(
                Role.ASSISTANT,
                'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.',  # noqa: E501
            ).with_channel("analysis"),
            Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
            .with_channel("commentary")
            .with_recipient("functions.get_current_weather")
            .with_content_type("json"),
            Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
            .with_channel("commentary")
            .with_recipient("functions.get_user_location")
            .with_content_type("json"),
            # JSON payload but no explicit content type.
            Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
            .with_channel("commentary")
            .with_recipient("functions.no_content_type"),
            # Non-JSON payload and no content type: passed through verbatim.
            Message.from_role_and_content(Role.ASSISTANT, "foo")
            .with_channel("commentary")
            .with_recipient("functions.not_json_no_content_type"),
            Message.from_role_and_content(Role.ASSISTANT, "{}")
            .with_channel("commentary")
            .with_recipient("functions.empty_args")
            .with_content_type("json"),
            Message.from_role_and_content(Role.ASSISTANT, "")
            .with_channel("commentary")
            .with_recipient("functions.no_args")
            .with_content_type("json"),
        ]
    )
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo,
        Role.ASSISTANT,
    )

    extracted_info = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )
    assert extracted_info.tools_called
    expected_tool_calls = [
        ToolCall(
            function=FunctionCall(
                name="get_current_weather",
                arguments=json.dumps({"location": "Tokyo"}),
            )
        ),
        ToolCall(
            function=FunctionCall(
                name="get_user_location",
                arguments=json.dumps({"location": "Tokyo"}),
            )
        ),
        ToolCall(
            function=FunctionCall(
                name="no_content_type",
                arguments=json.dumps({"location": "Tokyo"}),
            )
        ),
        ToolCall(
            function=FunctionCall(
                name="not_json_no_content_type",
                arguments="foo",
            )
        ),
        ToolCall(
            function=FunctionCall(
                name="empty_args",
                arguments=json.dumps({}),
            )
        ),
        ToolCall(
            function=FunctionCall(
                name="no_args",
                arguments="",
            )
        ),
    ]
    assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
    assert extracted_info.content is None
def test_extract_tool_calls_with_content(
    openai_tool_parser,
    harmony_encoding,
):
    """A final-channel message alongside a tool call is surfaced as content."""
    final_content = "This tool call will get the weather."
    convo = Conversation.from_messages(
        [
            Message.from_role_and_content(
                Role.USER, "What is the weather in Tokyo based on where I'm at?"
            ),
            Message.from_role_and_content(
                Role.ASSISTANT,
                'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.',  # noqa: E501
            ).with_channel("analysis"),
            Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
            .with_channel("commentary")
            .with_recipient("functions.get_current_weather")
            .with_content_type("json"),
            Message.from_role_and_content(Role.ASSISTANT, final_content).with_channel(
                "final"
            ),
        ]
    )
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo,
        Role.ASSISTANT,
    )

    extracted_info = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )
    assert extracted_info.tools_called
    expected_tool_calls = [
        ToolCall(
            function=FunctionCall(
                name="get_current_weather",
                arguments=json.dumps({"location": "Tokyo"}),
            )
        ),
    ]
    assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
    assert extracted_info.content == final_content

View File

@@ -0,0 +1,976 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
import pytest
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaMessage,
FunctionCall,
ToolCall,
)
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser,
)
from vllm.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser
MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
@pytest.fixture(scope="module")
def qwen3_tokenizer():
    """Module-scoped tokenizer for the Qwen3-Coder model."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
@pytest.fixture
def qwen3_tool_parser(qwen3_tokenizer):
    """Qwen3-Coder tool parser instance."""
    parser = Qwen3CoderToolParser(qwen3_tokenizer)
    return parser
@pytest.fixture
def qwen3_xml_tool_parser(qwen3_tokenizer):
    """XML-based Qwen3 tool parser instance."""
    parser = Qwen3XMLToolParser(qwen3_tokenizer)
    return parser
@pytest.fixture(params=["xml"])
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request):
    """Parameterized fixture over parser implementations.

    NOTE(review): only "xml" is currently listed in ``params``, so the
    "original" branch below is never taken; add "original" to ``params``
    to exercise both parsers.
    """
    if request.param == "original":
        return qwen3_tool_parser
    else:
        return qwen3_xml_tool_parser
@pytest.fixture
def sample_tools():
    """Two tool schemas used across tests: a weather lookup and an area calculator."""
    return [
        ChatCompletionToolsParam(
            type="function",
            function={
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string", "description": "The city name"},
                        "state": {"type": "string", "description": "The state code"},
                        "unit": {"type": "string", "enum": ["fahrenheit", "celsius"]},
                    },
                    "required": ["city", "state"],
                },
            },
        ),
        ChatCompletionToolsParam(
            type="function",
            function={
                "name": "calculate_area",
                "description": "Calculate area of a shape",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "shape": {"type": "string"},
                        "dimensions": {"type": "object"},
                        "precision": {"type": "integer"},
                    },
                },
            },
        ),
    ]
def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """Pairwise comparison; arguments are compared as parsed JSON, not raw text."""
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        # Qwen3 parser doesn't generate IDs during extraction
        assert actual.type == "function"
        assert actual.function.name == expected.function.name
        assert json.loads(actual.function.arguments) == json.loads(
            expected.function.arguments
        )
def stream_delta_message_generator(
    qwen3_tool_parser,
    qwen3_tokenizer: TokenizerLike,
    model_output: str,
    request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
    """Replay *model_output* through the parser one token at a time.

    Encodes the full output, then detokenizes incrementally so the parser
    receives realistic single-token text deltas, yielding every non-empty
    DeltaMessage that ``extract_tool_calls_streaming`` produces.
    """
    all_token_ids = qwen3_tokenizer.encode(model_output, add_special_tokens=False)

    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0

    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[: i + 1]

        # Incremental detokenization yields the partial-text delta for this
        # single token, carrying the offsets forward between iterations.
        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
            detokenize_incrementally(
                tokenizer=qwen3_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            )
        )

        current_text = previous_text + delta_text

        delta_message = qwen3_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=request,
        )
        if delta_message:
            yield delta_message

        previous_text = current_text
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset
def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized):
    """Output without <tool_call> tags is returned untouched as content."""
    plain_output = "This is a test response without any tool calls"
    result = qwen3_tool_parser_parametrized.extract_tool_calls(
        plain_output, request=None
    )  # type: ignore[arg-type]
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == plain_output
@pytest.mark.parametrize(
    ids=[
        "single_tool",
        "single_tool_with_content",
        "single_tool_multiline_param",
        "parallel_tools",
        "tool_with_typed_params",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            None,
        ),
        # Free text before the first <tool_call> becomes the content.
        (
            """Sure! Let me check the weather for you.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            "Sure! Let me check the weather for you.",
        ),
        # A parameter value spanning multiple lines (embedded JSON object).
        (
            """<tool_call>
<function=calculate_area>
<parameter=shape>
rectangle
</parameter>
<parameter=dimensions>
{"width": 10,
"height": 20}
</parameter>
<parameter=precision>
2
</parameter>
</function>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="calculate_area",
                        arguments=json.dumps(
                            {
                                "shape": "rectangle",
                                "dimensions": {"width": 10, "height": 20},
                                "precision": 2,
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Orlando
</parameter>
<parameter=state>
FL
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
                        ),
                    )
                ),
            ],
            None,
        ),
        (
            """Let me calculate that area for you.<tool_call>
<function=calculate_area>
<parameter=shape>
circle
</parameter>
<parameter=dimensions>
{"radius": 15.5}
</parameter>
<parameter=precision>
3
</parameter>
</function>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="calculate_area",
                        arguments=json.dumps(
                            {
                                "shape": "circle",
                                "dimensions": {"radius": 15.5},
                                "precision": 3,
                            }
                        ),
                    )
                )
            ],
            "Let me calculate that area for you.",
        ),
    ],
)
def test_extract_tool_calls(
    qwen3_tool_parser_parametrized,
    sample_tools,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Non-streaming extraction of the Qwen3 XML-style tool-call format."""
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
    extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
        model_output, request=request
    )
    assert extracted_tool_calls.tools_called
    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
    assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_fallback_no_tags(
    qwen3_tool_parser_parametrized, sample_tools
):
    """Test fallback parsing when XML tags are missing"""
    # Bare <function=...> block without the surrounding <tool_call> wrapper.
    bare_function_output = """<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>"""

    request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
    result = qwen3_tool_parser_parametrized.extract_tool_calls(
        bare_function_output, request=request
    )

    assert result.tools_called
    assert len(result.tool_calls) == 1
    assert result.tool_calls[0].function.name == "get_current_weather"
def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized):
    """Test parameter type conversion based on tool schema"""
    # NOTE(review): "float" is not a standard JSON Schema type ("number" is);
    # the parser apparently accepts it here - confirm that is intentional.
    tools = [
        ChatCompletionToolsParam(
            type="function",
            function={
                "name": "test_types",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "int_param": {"type": "integer"},
                        "float_param": {"type": "float"},
                        "bool_param": {"type": "boolean"},
                        "str_param": {"type": "string"},
                        "obj_param": {"type": "object"},
                    },
                },
            },
        )
    ]

    model_output = """<tool_call>
<function=test_types>
<parameter=int_param>
42
</parameter>
<parameter=float_param>
3.14
</parameter>
<parameter=bool_param>
true
</parameter>
<parameter=str_param>
hello world
</parameter>
<parameter=obj_param>
{"key": "value"}
</parameter>
</function>
</tool_call>"""

    request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
    extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
        model_output, request=request
    )

    # Each textual parameter value must be converted to its declared type.
    args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
    assert args["int_param"] == 42
    assert args["float_param"] == 3.14
    assert args["bool_param"] is True
    assert args["str_param"] == "hello world"
    assert args["obj_param"] == {"key": "value"}
# Streaming variants of the extraction scenarios: each output is replayed
# token-by-token through stream_delta_message_generator, and the streamed
# deltas are re-assembled to confirm they add up to the same tool calls and
# content as the non-streaming path.
@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool",
        "single_tool_with_content",
        "single_tool_multiline_param",
        "parallel_tools",
        "tool_with_typed_params",  # Added this test case
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ("This is a test without tools", [], "This is a test without tools"),
        (
            """<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """Sure! Let me check the weather for you.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            "Sure! Let me check the weather for you.",
        ),
        (
            """<tool_call>
<function=calculate_area>
<parameter=shape>
rectangle
</parameter>
<parameter=dimensions>
{"width": 10,
"height": 20}
</parameter>
<parameter=precision>
2
</parameter>
</function>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="calculate_area",
                        arguments=json.dumps(
                            {
                                "shape": "rectangle",
                                "dimensions": {"width": 10, "height": 20},
                                "precision": 2,
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Orlando
</parameter>
<parameter=state>
FL
</parameter>
<parameter=unit>
celsius
</parameter>
</function>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Orlando", "state": "FL", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            None,
        ),
        # Added tool_with_typed_params test case
        (
            """Let me calculate that area for you.<tool_call>
<function=calculate_area>
<parameter=shape>
circle
</parameter>
<parameter=dimensions>
{"radius": 15.5}
</parameter>
<parameter=precision>
3
</parameter>
</function>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="calculate_area",
                        arguments=json.dumps(
                            {
                                "shape": "circle",
                                "dimensions": {"radius": 15.5},
                                "precision": 3,
                            }
                        ),
                    )
                )
            ],
            "Let me calculate that area for you.",
        ),
    ],
)
def test_extract_tool_calls_streaming(
    qwen3_tool_parser_parametrized,
    qwen3_tokenizer,
    sample_tools,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Test incremental streaming behavior including typed parameters.

    Accumulates streamed deltas per tool index and checks the reassembled
    id/type/name/arguments against the expected tool calls.
    """
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
    other_content = ""
    tool_states = {}  # Track state per tool index
    for delta_message in stream_delta_message_generator(
        qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
    ):
        # role should never be streamed from tool parser
        assert not delta_message.role
        if delta_message.content:
            other_content += delta_message.content
        if delta_message.tool_calls:
            for tool_call in delta_message.tool_calls:
                idx = tool_call.index
                # Initialize state for new tool
                if idx not in tool_states:
                    tool_states[idx] = {
                        "id": None,
                        "name": None,
                        "arguments": "",
                        "type": None,
                    }
                # First chunk should have id, name, and type
                if tool_call.id:
                    tool_states[idx]["id"] = tool_call.id
                if tool_call.type:
                    assert tool_call.type == "function"
                    tool_states[idx]["type"] = tool_call.type
                if tool_call.function:
                    if tool_call.function.name:
                        # Should only be set once
                        assert tool_states[idx]["name"] is None
                        tool_states[idx]["name"] = tool_call.function.name
                    if tool_call.function.arguments is not None:
                        # Accumulate arguments incrementally
                        tool_states[idx]["arguments"] += tool_call.function.arguments
    # Verify final content
    assert other_content == (expected_content or "")  # Handle None case
    # Verify we got all expected tool calls
    assert len(tool_states) == len(expected_tool_calls)
    assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == len(
        expected_tool_calls
    )
    # Verify each tool call
    for idx, expected_tool in enumerate(expected_tool_calls):
        state = tool_states[idx]
        assert state["id"] is not None
        assert state["type"] == "function"
        assert state["name"] == expected_tool.function.name
        # Parse accumulated arguments
        arguments_str = state["arguments"]
        assert arguments_str is not None
        actual_args = json.loads(arguments_str)
        expected_args = json.loads(expected_tool.function.arguments)
        assert actual_args == expected_args
def test_extract_tool_calls_missing_closing_parameter_tag(
    qwen3_tool_parser_parametrized, sample_tools
):
    """A missing </parameter> tag must not break non-streaming extraction."""
    # get_current_weather call whose first parameter never closes.
    malformed_output = """Let me check the weather for you:
<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"""
    chat_request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
    result = qwen3_tool_parser_parametrized.extract_tool_calls(
        malformed_output, request=chat_request
    )
    # Malformed XML is handled gracefully: one complete tool call comes back.
    assert result.tools_called
    assert len(result.tool_calls) == 1
    call = result.tool_calls[0]
    assert call.function.name == "get_current_weather"
    # Arguments are recovered despite the missing closing tag.
    parsed_args = json.loads(call.function.arguments)
    assert "city" in parsed_args
    assert parsed_args["city"] == "Dallas"
    assert parsed_args["state"] == "TX"
    assert parsed_args["unit"] == "fahrenheit"
    # Leading assistant text survives as ordinary content.
    assert "Let me check the weather for you:" in result.content
def test_extract_tool_calls_streaming_missing_closing_tag(
    qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
):
    """Streaming parse still succeeds when a </parameter> tag is missing."""
    # get_current_weather call whose first parameter never closes.
    malformed_output = """Let me check the weather for you:
<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"""
    chat_request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
    streamed_content = ""
    states = {}
    for delta in stream_delta_message_generator(
        qwen3_tool_parser_parametrized, qwen3_tokenizer, malformed_output, chat_request
    ):
        if delta.content:
            streamed_content += delta.content
        for chunk in delta.tool_calls or []:
            state = states.setdefault(
                chunk.index,
                {"id": None, "name": None, "arguments": "", "type": None},
            )
            if chunk.id:
                state["id"] = chunk.id
            if chunk.type:
                assert chunk.type == "function"
                state["type"] = chunk.type
            if chunk.function:
                if chunk.function.name:
                    state["name"] = chunk.function.name
                if chunk.function.arguments is not None:
                    state["arguments"] += chunk.function.arguments
    # The leading assistant text was streamed as plain content.
    assert "Let me check the weather for you:" in streamed_content
    # Exactly one tool call was produced despite the malformed XML.
    assert len(states) == 1
    assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1
    final_state = states[0]
    assert final_state["id"] is not None
    assert final_state["type"] == "function"
    assert final_state["name"] == "get_current_weather"
    # Arguments were recovered despite the missing closing tag.
    assert final_state["arguments"] is not None
    recovered = json.loads(final_state["arguments"])
    assert recovered["city"] == "Dallas"
    assert recovered["state"] == "TX"
    assert recovered["unit"] == "fahrenheit"
def test_extract_tool_calls_streaming_incremental(
    qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
):
    """Streaming must emit many small deltas, not one monolithic chunk."""
    model_output = """I'll check the weather.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>"""
    chat_request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
    chunks = list(
        stream_delta_message_generator(
            qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, chat_request
        )
    )
    # The reply must be split over several deltas.
    assert len(chunks) > 3
    # Plain content arrives first, before any tool-call delta.
    assert chunks[0].content is not None
    assert not chunks[0].tool_calls
    # One chunk carries the header: id, name, type and (initially) empty args.
    header = next((c for c in chunks if c.tool_calls and c.tool_calls[0].id), None)
    assert header is not None
    header_call = header.tool_calls[0]
    assert header_call.function.name == "get_current_weather"
    assert header_call.type == "function"
    assert header_call.function.arguments == ""
    # Argument text is spread over multiple chunks...
    arg_pieces = [
        c.tool_calls[0].function.arguments
        for c in chunks
        if c.tool_calls and c.tool_calls[0].function.arguments
    ]
    assert len(arg_pieces) > 1
    # ...and concatenates back into valid JSON.
    parsed_args = json.loads("".join(arg_pieces))
    assert parsed_args["city"] == "Dallas"
    assert parsed_args["state"] == "TX"
def test_extract_tool_calls_complex_type_with_single_quote(
    qwen3_tool_parser_parametrized,
):
    """A single-quoted (Python-style) dict for an object parameter is
    normalized into valid JSON in the extracted arguments."""
    tools = [
        ChatCompletionToolsParam(
            type="function",
            function={
                "name": "test_types",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "int_param": {"type": "integer"},
                        "float_param": {"type": "float"},
                        "bool_param": {"type": "boolean"},
                        "str_param": {"type": "string"},
                        "obj_param": {"type": "object"},
                    },
                },
            },
        )
    ]
    # Single quotes instead of the double quotes JSON requires.
    model_output = """<tool_call>
<function=test_types>
<parameter=obj_param>
{'key': 'value'}
</parameter>
</function>
</tool_call>"""
    chat_request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
    extraction = qwen3_tool_parser_parametrized.extract_tool_calls(
        model_output, request=chat_request
    )
    decoded = json.loads(extraction.tool_calls[0].function.arguments)
    assert decoded["obj_param"] == {"key": "value"}
def test_extract_tool_calls_streaming_missing_opening_tag(
    qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
):
    """Test streaming with missing opening <tool_call> tag
    This tests that the streaming parser correctly handles
    tool calls that start directly with <function=...>
    """
    # NOTE: the trailing </tool_call> is intentionally unbalanced -- there is
    # no matching opening tag before <function=...>.
    model_output = """I'll check the weather for you.
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"""
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
    other_content = ""
    tool_states = {}  # per-tool-index accumulation of streamed fields
    for delta_message in stream_delta_message_generator(
        qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
    ):
        if delta_message.content:
            other_content += delta_message.content
        if delta_message.tool_calls:
            for tool_call in delta_message.tool_calls:
                idx = tool_call.index
                if idx not in tool_states:
                    tool_states[idx] = {
                        "id": None,
                        "name": None,
                        "arguments": "",
                        "type": None,
                    }
                if tool_call.id:
                    tool_states[idx]["id"] = tool_call.id
                if tool_call.type:
                    assert tool_call.type == "function"
                    tool_states[idx]["type"] = tool_call.type
                if tool_call.function:
                    if tool_call.function.name:
                        tool_states[idx]["name"] = tool_call.function.name
                    if tool_call.function.arguments is not None:
                        tool_states[idx]["arguments"] += tool_call.function.arguments
    # Verify content was streamed
    assert "I'll check the weather for you." in other_content
    # Verify we got the tool call
    assert len(tool_states) == 1
    assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1
    state = tool_states[0]
    assert state["id"] is not None
    assert state["type"] == "function"
    assert state["name"] == "get_current_weather"
    # Verify arguments were parsed correctly despite missing opening tag
    assert state["arguments"] is not None
    args = json.loads(state["arguments"])
    assert args["city"] == "Dallas"
    assert args["state"] == "TX"
    assert args["unit"] == "fahrenheit"

View File

@@ -0,0 +1,495 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
from collections.abc import Generator
import pytest
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaMessage,
FunctionCall,
ToolCall,
)
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tool_parsers.seed_oss_tool_parser import SeedOssToolParser
# Use a common model that is likely to be available
MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct"
@pytest.fixture(scope="module")
def seed_oss_tokenizer():
    """Load the Seed-OSS tokenizer once per test module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
    return tokenizer
@pytest.fixture
def seed_oss_tool_parser(seed_oss_tokenizer):
    """Build a fresh SeedOssToolParser instance for each test."""
    parser = SeedOssToolParser(seed_oss_tokenizer)
    return parser
@pytest.fixture
def sample_tools():
    """A single `get_weather` tool definition shared by the tests below."""
    weather_function = {
        "name": "get_weather",
        "description": "Get current temperature for a given location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "City and country e.g. Bogotá, Colombia",
                },
                "unit": {
                    "type": "string",
                    "description": "this is the unit of temperature",
                },
            },
            "required": ["location"],
            "additionalProperties": False,
        },
        "returns": {
            "type": "object",
            "properties": {
                "temperature": {
                    "type": "number",
                    "description": "temperature in celsius",
                }
            },
            "required": ["temperature"],
            "additionalProperties": False,
        },
        "strict": True,
    }
    return [ChatCompletionToolsParam(type="function", function=weather_function)]
def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """Compare parsed tool calls against expectations, pairwise.

    Seed-OSS emits no tool-call ids, so only the type and function payload
    are compared.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)
    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        assert actual.type == "function"
        assert actual.function == expected.function
        # Redundant with the equality above, but pinpoints which field broke.
        assert actual.function.name == expected.function.name
        assert actual.function.arguments == expected.function.arguments
def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
    """Output with no tool markup is returned unchanged as content."""
    plain_output = "This is a test response without any tool calls"
    result = seed_oss_tool_parser.extract_tool_calls(
        plain_output,
        request=None,  # type: ignore[arg-type]
    )
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == plain_output
# Non-streaming extraction across three thinking-budget regimes: no
# <seed:think> trace at all, a budget-limited trace, and an unlimited trace.
# The reasoning text must stay in `content`; only the <seed:tool_call> block
# becomes a tool call.
# NOTE(review): the "thinkg" typo in the second id is kept as-is -- renaming
# it would change the collected pytest test ids.
@pytest.mark.parametrize(
    ids=[
        "tool_call_0_thinking_budget",
        "tool_call_512_thinkg_budget",
        "tool_call_unlimited_thinking_budget",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """<seed:tool_call>\n<function=get_weather>\n"""
            """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "location": "Barcelona, Spain",
                            },
                        ),
                    ),
                    type="function",
                )
            ],
            None,
        ),
        (
            """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
            """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
            """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
            """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
            """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
            """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
            """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
            """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
            """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
            """user\'s input has a space, but the function might accept either; to be safe, using the standard format """
            """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
            """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
            """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
            """\n</seed:tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "location": "Barcelona, Spain",
                            },
                        ),
                    ),
                    type="function",
                )
            ],
            """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
            """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
            """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
            """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
            """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
            """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
            """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
            """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
            """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
            """user\'s input has a space, but the function might accept either; to be safe, using the standard format """
            """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
            """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
        ),
        (
            """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
            """First, I need to remember the function I can use: get_weather. The function requires a """
            """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
            """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
            """let me check the function docstring again. Oh, the function says unit is optional, and """
            """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
            """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
            """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
            """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
            """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
            """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
            """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
            """call should be as above. Then wait for the result to come back and tell the user the """
            """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
            """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "location": "Barcelona, Spain",
                                "unit": "celsius",
                            },
                        ),
                    ),
                    type="function",
                )
            ],
            """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
            """First, I need to remember the function I can use: get_weather. The function requires a """
            """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
            """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
            """let me check the function docstring again. Oh, the function says unit is optional, and """
            """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
            """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
            """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
            """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
            """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
            """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
            """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
            """call should be as above. Then wait for the result to come back and tell the user the """
            """temperature in Celsius.</seed:think>""",
        ),
    ],
)
def test_extract_tool_calls(
    seed_oss_tool_parser,
    sample_tools,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Non-streaming extraction yields the expected tool calls and content."""
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
    extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
        model_output, request=request
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
    assert extracted_tool_calls.content == expected_content
def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
    """A streaming delta with no tool markup is passed through as content.

    Fixes the original test data, which violated the streaming invariant
    ``current_text == previous_text + delta_text`` (the previous text dropped
    the leading "T" and the delta carried a trailing "." that was absent from
    the current text).
    """
    model_output = "This is a test response without any tool calls"
    previous_text = "This is a test response"
    delta_text = " without any tool calls"
    # Sanity-check the invariant the parser is entitled to rely on.
    assert previous_text + delta_text == model_output
    result = seed_oss_tool_parser.extract_tool_calls_streaming(
        previous_text=previous_text,
        current_text=model_output,
        delta_text=delta_text,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )
    # With no tool tokens present, the delta should surface as plain content.
    assert result is not None
    assert hasattr(result, "content")
    assert result.content == delta_text
def stream_delta_message_generator(
    seed_oss_tool_parser: SeedOssToolParser,
    seed_oss_tokenizer: TokenizerLike,
    model_output: str,
    request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
    """Replay *model_output* through the parser one token at a time.

    Detokenizes incrementally so every delta string matches what a real
    streaming decode would produce, and yields each non-empty DeltaMessage.
    """
    token_ids = seed_oss_tokenizer.encode(model_output, add_special_tokens=False)
    text_so_far = ""
    tokens_so_far = None
    prefix_offset = 0
    read_offset = 0
    for position, token_id in enumerate(token_ids):
        seen_ids = token_ids[: position + 1]
        new_tokens, delta_text, next_prefix, next_read = detokenize_incrementally(
            tokenizer=seed_oss_tokenizer,
            all_input_ids=seen_ids,
            prev_tokens=tokens_so_far,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=False,
            spaces_between_special_tokens=True,
        )
        current_text = text_so_far + delta_text
        delta_message = seed_oss_tool_parser.extract_tool_calls_streaming(
            text_so_far,
            current_text,
            delta_text,
            token_ids[:position],
            seen_ids,
            [token_id],
            request=request,
        )
        if delta_message:
            yield delta_message
        # Roll the incremental-detokenization state forward.
        text_so_far = current_text
        tokens_so_far = tokens_so_far + new_tokens if tokens_so_far else new_tokens
        prefix_offset = next_prefix
        read_offset = next_read
# Streaming counterparts of the extraction cases: the same outputs (plus a
# zero-thinking-budget variant) are replayed token-by-token and the deltas
# re-assembled per tool index.
# NOTE(review): the "thinkg" typo in the second id is kept as-is -- renaming
# it would change the collected pytest test ids.
@pytest.mark.parametrize(
    ids=[
        "tool_call_0_thinking_budget",
        "tool_call_512_thinkg_budget",
        "tool_call_unlimited_thinking_budget",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
            """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
            """<seed:tool_call>\n<function=get_weather>\n"""
            """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "location": "Barcelona, Spain",
                            },
                        ),
                    ),
                    type="function",
                )
            ],
            """<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
            """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""",
        ),
        (
            """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
            """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
            """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
            """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
            """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
            """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
            """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
            """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
            """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
            """user\'s input has a space, but the function might accept either; to be safe, using the standard format """
            """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
            """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
            """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
            """\n</seed:tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "location": "Barcelona, Spain",
                            },
                        ),
                    ),
                    type="function",
                )
            ],
            """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
            """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
            """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
            """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
            """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
            """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
            """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
            """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
            """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
            """user\'s input has a space, but the function might accept either; to be safe, using the standard format """
            """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
            """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
        ),
        (
            """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
            """First, I need to remember the function I can use: get_weather. The function requires a """
            """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
            """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
            """let me check the function docstring again. Oh, the function says unit is optional, and """
            """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
            """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
            """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
            """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
            """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
            """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
            """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
            """call should be as above. Then wait for the result to come back and tell the user the """
            """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
            """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "location": "Barcelona, Spain",
                                "unit": "celsius",
                            },
                        ),
                    ),
                    type="function",
                )
            ],
            """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
            """First, I need to remember the function I can use: get_weather. The function requires a """
            """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
            """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
            """let me check the function docstring again. Oh, the function says unit is optional, and """
            """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
            """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
            """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
            """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
            """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
            """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
            """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
            """call should be as above. Then wait for the result to come back and tell the user the """
            """temperature in Celsius.</seed:think>""",
        ),
    ],
)
def test_streaming_tool_calls(
    seed_oss_tool_parser,
    seed_oss_tokenizer,
    sample_tools,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Test incremental streaming behavior"""
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
    other_content = ""
    tool_states = {}  # Track state per tool index
    for delta_message in stream_delta_message_generator(
        seed_oss_tool_parser, seed_oss_tokenizer, model_output, request
    ):
        # role should never be streamed from tool parser
        assert not delta_message.role
        if delta_message.content:
            other_content += delta_message.content
        if delta_message.tool_calls:
            for tool_call in delta_message.tool_calls:
                idx = tool_call.index
                # Initialize state for new tool
                if idx not in tool_states:
                    tool_states[idx] = {
                        "id": None,
                        "name": None,
                        "arguments": "",
                        "type": None,
                    }
                # First chunk should have id, name, and type
                if tool_call.id:
                    tool_states[idx]["id"] = tool_call.id
                if tool_call.type:
                    assert tool_call.type == "function"
                    tool_states[idx]["type"] = tool_call.type
                if tool_call.function:
                    if tool_call.function.name:
                        # Should only be set once
                        assert tool_states[idx]["name"] is None
                        tool_states[idx]["name"] = tool_call.function.name
                    if tool_call.function.arguments is not None:
                        # Accumulate arguments incrementally
                        tool_states[idx]["arguments"] += tool_call.function.arguments
    # Verify final content
    assert other_content == expected_content
    # Verify we got all expected tool calls
    assert len(tool_states) == len(expected_tool_calls)
    # Verify each tool call
    for idx, expected_tool in enumerate(expected_tool_calls):
        state = tool_states[idx]
        assert state["id"] is not None
        assert state["type"] == "function"
        assert state["name"] == expected_tool.function.name
        # Parse accumulated arguments
        arguments_str = state["arguments"]
        assert arguments_str is not None
        actual_args = json.loads(arguments_str)
        expected_args = json.loads(expected_tool.function.arguments)
        assert actual_args == expected_args

View File

@@ -0,0 +1,534 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
import pytest
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
FunctionCall,
ToolCall,
)
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tool_parsers.xlam_tool_parser import xLAMToolParser
# Use a common model that is likely to be available
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
@pytest.fixture(scope="module")
def xlam_tokenizer():
    """Load the xLAM tokenizer once and share it across the module's tests."""
    return get_tokenizer(tokenizer_name=MODEL)
@pytest.fixture
def xlam_tool_parser(xlam_tokenizer):
    """Create a fresh parser per test so streaming state never leaks between tests."""
    return xLAMToolParser(xlam_tokenizer)
def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """Compare parsed tool calls against expectations, ignoring generated ids."""
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        # ids are randomly generated per call; validate only their shape.
        assert isinstance(actual.id, str)
        assert len(actual.id) > 16
        assert actual.type == "function"
        # The function payload (name + arguments) must match exactly.
        assert actual.function == expected.function
def stream_delta_message_generator(
    xlam_tool_parser: xLAMToolParser,
    xlam_tokenizer: TokenizerLike,
    model_output: str,
    request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
    """Simulate token-by-token streaming of *model_output* through the parser.

    Encodes the full output, then feeds it to the parser one token at a time,
    detokenizing incrementally so each call sees the same previous/current
    text and token-id views it would during real streaming. Yields every
    non-empty DeltaMessage the parser produces.
    """
    all_token_ids = xlam_tokenizer.encode(model_output, add_special_tokens=False)

    # Incremental-detokenization state threaded through each iteration.
    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[: i + 1]

        # Convert the newly added token into its text delta, carrying the
        # offsets forward so multi-token characters decode correctly.
        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
            detokenize_incrementally(
                tokenizer=xlam_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            )
        )

        current_text = previous_text + delta_text

        delta_message = xlam_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=request,
        )
        # The parser may return None when the delta produces no output.
        if delta_message:
            yield delta_message

        # Advance state only after the parser has seen this delta.
        previous_text = current_text
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset
def test_extract_tool_calls_no_tools(xlam_tool_parser):
    """Plain text with no tool markup is returned untouched as content."""
    plain_text = "This is a test"
    result = xlam_tool_parser.extract_tool_calls(
        plain_text, request=None
    )  # type: ignore[arg-type]
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == plain_text
def _weather_call(city: str, state: str, unit: str) -> ToolCall:
    """Build the expected ``get_current_weather`` ToolCall for one location."""
    return ToolCall(
        function=FunctionCall(
            name="get_current_weather",
            arguments=json.dumps({"city": city, "state": state, "unit": unit}),
        )
    )


@pytest.mark.parametrize(
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
        "single_tool_with_tool_call_xml_tags",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                _weather_call("Dallas", "TX", "fahrenheit"),
                _weather_call("Orlando", "FL", "fahrenheit"),
            ],
            None,
        ),
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [_weather_call("Dallas", "TX", "fahrenheit")],
            "<think>I'll help you with that.</think>",
        ),
        (
            """I'll help you with that.\n```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [_weather_call("Dallas", "TX", "fahrenheit")],
            "I'll help you with that.",
        ),
        (
            """I'll check the weather for you.[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [_weather_call("Dallas", "TX", "fahrenheit")],
            "I'll check the weather for you.",
        ),
        (
            """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [_weather_call("Dallas", "TX", "fahrenheit")],
            "I'll help you check the weather.",
        ),
    ],
)
def test_extract_tool_calls(
    xlam_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction across the markup styles xLAM can emit."""
    result = xlam_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content
@pytest.mark.parametrize(
    ids=["list_structured_tool_call"],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            '[{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}]',  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Seattle", "state": "WA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls_list_structure(
    xlam_tool_parser, model_output, expected_tool_calls, expected_content
):
    """A bare JSON list of calls (no wrapper tags) is parsed as tool calls."""
    result = xlam_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content
# Test for preprocess_model_output method
def test_preprocess_model_output(xlam_tool_parser):
    """preprocess_model_output splits leading content from candidate tool JSON."""
    # Bare list structure: everything is a potential tool call, no content.
    bare_list = '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]'
    content, candidate = xlam_tool_parser.preprocess_model_output(bare_list)
    assert content is None
    assert candidate == bare_list

    # Thinking tag followed by a tool-call list: tag text becomes content.
    with_think = (
        "<think>I'll help you with that.</think>"
        '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]'
    )
    content, candidate = xlam_tool_parser.preprocess_model_output(with_think)
    assert content == "<think>I'll help you with that.</think>"
    assert (
        candidate
        == '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]'
    )

    # Prose followed by a fenced JSON block: prose becomes content.
    with_code_block = (
        "I'll help you with that.\n"
        "```json\n"
        '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]\n'
        "```"
    )
    content, candidate = xlam_tool_parser.preprocess_model_output(with_code_block)
    assert content == "I'll help you with that."
    assert "get_current_weather" in candidate

    # Plain prose: no tool-call candidate is detected at all.
    prose = "I'll help you with that."
    content, candidate = xlam_tool_parser.preprocess_model_output(prose)
    assert content == prose
    assert candidate is None
# Simulate streaming to test extract_tool_calls_streaming
def test_streaming_with_list_structure(xlam_tool_parser):
    """Drive extract_tool_calls_streaming manually for a list-structured call."""
    # Clear any streaming state carried on the parser fixture.
    xlam_tool_parser.prev_tool_calls = []
    xlam_tool_parser.current_tools_sent = []
    xlam_tool_parser.streamed_args = []
    xlam_tool_parser.current_tool_id = -1

    full_text = '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]'

    # First call: feed the closing bracket so the parser sets up the tool.
    xlam_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=full_text,
        delta_text="]",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )
    assert xlam_tool_parser.current_tool_id >= 0, "Tool index should be initialized"

    # Force the "name not yet sent" state, then request the next delta.
    xlam_tool_parser.current_tools_sent = [False]
    result = xlam_tool_parser.extract_tool_calls_streaming(
        previous_text=full_text,
        current_text=full_text,
        delta_text="",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # When a delta is produced it must carry the function name.
    if result is not None:
        assert hasattr(result, "tool_calls")
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "get_current_weather"
def _weather_call(city: str, state: str, unit: str) -> ToolCall:
    """Build the expected ``get_current_weather`` ToolCall for one location."""
    return ToolCall(
        function=FunctionCall(
            name="get_current_weather",
            arguments=json.dumps({"city": city, "state": state, "unit": unit}),
        )
    )


@pytest.mark.parametrize(
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
        "single_tool_with_tool_call_xml_tags",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                _weather_call("Dallas", "TX", "fahrenheit"),
                _weather_call("Orlando", "FL", "fahrenheit"),
            ],
            "",
        ),
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [_weather_call("Dallas", "TX", "fahrenheit")],
            "<think>I'll help you with that.</think>",
        ),
        (
            """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [_weather_call("Dallas", "TX", "fahrenheit")],
            "",
        ),
        (
            """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [_weather_call("Dallas", "TX", "fahrenheit")],
            "",
        ),
        (
            """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [_weather_call("Dallas", "TX", "fahrenheit")],
            "I can help with that.",
        ),
    ],
)
def test_extract_tool_calls_streaming_incremental(
    xlam_tool_parser,
    xlam_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Verify the XLAM Parser streaming behavior by verifying each chunk is as expected."""  # noqa: E501
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
    chunks = list(
        stream_delta_message_generator(
            xlam_tool_parser, xlam_tokenizer, model_output, request
        )
    )

    # Streaming must yield several deltas, not one monolithic message.
    assert len(chunks) >= 3

    # Locate the header chunk (carries id/name/type) for the first tool call.
    first_expected = expected_tool_calls[0]
    header_found = False
    for chunk in chunks:
        if chunk.tool_calls and chunk.tool_calls[0].id:
            header_found = True
            assert (
                chunk.tool_calls[0].function.name == first_expected.function.name
            )
            assert chunk.tool_calls[0].type == "function"
            # Arguments on the header chunk are either absent or empty.
            if chunk.tool_calls[0].function.arguments is not None:
                assert chunk.tool_calls[0].function.arguments == ""
            break
    assert header_found

    # Collect the incremental argument fragments for tool index 0 only.
    arg_chunks = [
        chunk.tool_calls[0].function.arguments
        for chunk in chunks
        if chunk.tool_calls
        and chunk.tool_calls[0].function.arguments
        and chunk.tool_calls[0].function.arguments != ""
        and chunk.tool_calls[0].index == 0
    ]

    # Arguments must arrive in more than one piece...
    assert len(arg_chunks) > 1
    # ...and concatenate to valid JSON equal to the expected arguments.
    assert json.loads("".join(arg_chunks)) == json.loads(
        first_expected.function.arguments
    )