Sync from v0.13
0    tests/tool_parsers/__init__.py  (new file)
61   tests/tool_parsers/test_deepseekv31_tool_parser.py  (new file)
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.deepseekv31_tool_parser import (
    DeepSeekV31ToolParser,
)

MODEL = "deepseek-ai/DeepSeek-V3.1"


@pytest.fixture(scope="module")
def deepseekv31_tokenizer():
    return get_tokenizer(tokenizer_name=MODEL)


@pytest.fixture
def parser(deepseekv31_tokenizer):
    return DeepSeekV31ToolParser(deepseekv31_tokenizer)


def test_extract_tool_calls_with_tool(parser):
    model_output = (
        "normal text"
        + "<|tool▁calls▁begin|>"
        + '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>'
        + "<|tool▁calls▁end|>"
    )
    result = parser.extract_tool_calls(model_output, None)
    assert result.tools_called
    assert len(result.tool_calls) == 1
    assert result.tool_calls[0].function.name == "foo"
    assert result.tool_calls[0].function.arguments == '{"x":1}'
    assert result.content == "normal text"


def test_extract_tool_calls_with_multiple_tools(parser):
    model_output = (
        "some prefix text"
        + "<|tool▁calls▁begin|>"
        + '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>'
        + '<|tool▁call▁begin|>bar<|tool▁sep|>{"y":2}<|tool▁call▁end|>'
        + "<|tool▁calls▁end|>"
        + " some suffix text"
    )

    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called
    assert len(result.tool_calls) == 2

    assert result.tool_calls[0].function.name == "foo"
    assert result.tool_calls[0].function.arguments == '{"x":1}'

    assert result.tool_calls[1].function.name == "bar"
    assert result.tool_calls[1].function.arguments == '{"y":2}'

    # prefix is content
    assert result.content == "some prefix text"
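Aside: the DeepSeek-V3.1 marker grammar exercised by these two tests can be sketched with a plain regex. This is illustrative only; the DeepSeekV31ToolParser above is the actual implementation.

import re

# Marker strings copied from the test strings above (note the "▁" characters).
CALLS_BEGIN = "<|tool▁calls▁begin|>"
CALL_RE = re.compile(
    r"<\|tool▁call▁begin\|>(?P<name>.*?)<\|tool▁sep\|>(?P<args>.*?)<\|tool▁call▁end\|>",
    re.DOTALL,
)

def sketch_extract(model_output: str):
    """Return (content before the calls section, [(name, raw_json_args), ...])."""
    content = model_output.split(CALLS_BEGIN, 1)[0]
    return content, [(m["name"], m["args"]) for m in CALL_RE.finditer(model_output)]

# Mirrors test_extract_tool_calls_with_tool:
#   sketch_extract('normal text<|tool▁calls▁begin|><|tool▁call▁begin|>foo'
#                  '<|tool▁sep|>{"x":1}<|tool▁call▁end|><|tool▁calls▁end|>')
#   == ("normal text", [("foo", '{"x":1}')])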
359  tests/tool_parsers/test_ernie45_moe_tool_parser.py  (new file)
@@ -0,0 +1,359 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import json
from collections.abc import Generator

import pytest

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaMessage,
    FunctionCall,
    ToolCall,
)
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tool_parsers.ernie45_tool_parser import Ernie45ToolParser

# Use a common model that is likely to be available
MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking"


@pytest.fixture(scope="module")
def ernie45_tokenizer():
    return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)


@pytest.fixture
def ernie45_tool_parser(ernie45_tokenizer):
    return Ernie45ToolParser(ernie45_tokenizer)


def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) > 0

        assert actual_tool_call.type == "function"
        assert actual_tool_call.function.name == expected_tool_call.function.name
        # Compare arguments as JSON objects to handle formatting differences
        actual_args = json.loads(actual_tool_call.function.arguments)
        expected_args = json.loads(expected_tool_call.function.arguments)
        assert actual_args == expected_args


def test_extract_tool_calls_no_tools(ernie45_tool_parser):
    model_output = "This is a test"
    extracted_tool_calls = ernie45_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content == model_output


@pytest.mark.parametrize(
    ids=[
        "single_tool_call",
        "multiple_tool_calls",
        "tool_call_with_content_before",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            None,
        ),
        (
            """I need to call two tools to handle these two issues separately.
</think>

<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            "I need to call two tools to handle these two issues separately.\n</think>",
        ),
    ],
)
def test_extract_tool_calls(
    ernie45_tool_parser, model_output, expected_tool_calls, expected_content
):
    extracted_tool_calls = ernie45_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


def stream_delta_message_generator(
    ernie45_tool_parser: Ernie45ToolParser,
    ernie45_tokenizer: TokenizerLike,
    model_output: str,
    request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
    all_token_ids = ernie45_tokenizer.encode(model_output, add_special_tokens=False)

    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[: i + 1]

        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
            detokenize_incrementally(
                tokenizer=ernie45_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            )
        )

        current_text = previous_text + delta_text

        delta_message = ernie45_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=request,
        )
        if delta_message:
            yield delta_message

        previous_text = current_text
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset


@pytest.mark.parametrize(
    ids=[
        "single_tool_call",
        "multiple_tool_calls",
        "tool_call_with_content_before",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            None,
        ),
        (
            """I need to call two tools to handle these two issues separately.
</think>

<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            "I need to call two tools to handle these two issues separately.\n</think>",
        ),
    ],
)
def test_extract_tool_calls_streaming_incremental(
    ernie45_tool_parser,
    ernie45_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Verify the Ernie45 Parser streaming behavior by verifying each chunk is as expected."""  # noqa: E501
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])

    tool_calls_dict = {}
    for delta_message in stream_delta_message_generator(
        ernie45_tool_parser, ernie45_tokenizer, model_output, request
    ):
        if (
            delta_message.role is None
            and delta_message.content is None
            and delta_message.reasoning is None
            and len(delta_message.tool_calls) == 0
        ):
            continue
        tool_calls = delta_message.tool_calls
        for tool_call_chunk in tool_calls:
            index = tool_call_chunk.index
            if index not in tool_calls_dict:
                if tool_call_chunk.function.arguments is None:
                    tool_call_chunk.function.arguments = ""
                tool_calls_dict[index] = tool_call_chunk
            else:
                tool_calls_dict[
                    index
                ].function.arguments += tool_call_chunk.function.arguments
    actual_tool_calls = list(tool_calls_dict.values())

    assert len(actual_tool_calls) > 0
    # check tool call format
    assert_tool_calls(actual_tool_calls, expected_tool_calls)
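Aside: in the ERNIE format above, each <tool_call> block wraps one complete JSON object with "name" and "arguments" keys. A minimal sketch of that mapping, illustrative only and not the Ernie45ToolParser implementation:

import json
import re

# Each <tool_call> block carries one {"name": ..., "arguments": {...}} JSON object.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def sketch_ernie_calls(model_output: str) -> list[dict]:
    return [json.loads(block) for block in TOOL_CALL_RE.findall(model_output)]

calls = sketch_ernie_calls(
    '<tool_call>\n{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}\n</tool_call>\n'
)
assert calls[0]["name"] == "get_current_temperature"
assert calls[0]["arguments"] == {"location": "Beijing"}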
449  tests/tool_parsers/test_glm4_moe_tool_parser.py  (new file)
@@ -0,0 +1,449 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import json

import pytest

from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.glm4_moe_tool_parser import (
    Glm4MoeModelToolParser,
)

pytest.skip("skip glm4_moe parser test", allow_module_level=True)
# Use a common model that is likely to be available
MODEL = "zai-org/GLM-4.5"


@pytest.fixture(scope="module")
def glm4_moe_tokenizer():
    return get_tokenizer(tokenizer_name=MODEL)


@pytest.fixture
def glm4_moe_tool_parser(glm4_moe_tokenizer):
    return Glm4MoeModelToolParser(glm4_moe_tokenizer)


def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) > 0

        assert actual_tool_call.type == "function"
        assert actual_tool_call.function.name == expected_tool_call.function.name
        # Compare arguments as JSON objects to handle formatting differences
        actual_args = json.loads(actual_tool_call.function.arguments)
        expected_args = json.loads(expected_tool_call.function.arguments)
        assert actual_args == expected_args


def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
    model_output = "This is a test"
    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content == model_output


@pytest.mark.parametrize(
    ids=[
        "single_tool_call",
        "multiple_tool_calls",
        "tool_call_with_content_before",
        "tool_call_with_mixed_args",
        "tool_call_with_chinese_content",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Dallas</arg_value>
<arg_key>state</arg_key>
<arg_value>TX</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Dallas</arg_value>
<arg_key>state</arg_key>
<arg_value>TX</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call>
<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Orlando</arg_value>
<arg_key>state</arg_key>
<arg_value>FL</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Orlando",
                                "state": "FL",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
            ],
            None,
        ),
        (
            """I'll help you check the weather. <tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Seattle</arg_value>
<arg_key>state</arg_key>
<arg_value>WA</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Seattle",
                                "state": "WA",
                                "unit": "celsius",
                            }
                        ),
                    )
                )
            ],
            "I'll help you check the weather.",
        ),
        (
            """<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>New York</arg_value>
<arg_key>state</arg_key>
<arg_value>NY</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "New York",
                                "state": "NY",
                                "unit": "celsius",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """I will help you get the weather.<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "city": "Beijing",
                                "date": "2025-08-01",
                            }
                        ),
                    )
                )
            ],
            "I will help you get the weather.",
        ),
    ],
)
def test_extract_tool_calls(
    glm4_moe_tool_parser, model_output, expected_tool_calls, expected_content
):
    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


def test_extract_tool_calls_with_thinking_tags(glm4_moe_tool_parser):
    """Test tool extraction when thinking tags are present."""
    model_output = """<think>I want to get the weather.</think>

I will help you get the weather.
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>"""

    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    assert len(extracted_tool_calls.tool_calls) == 1
    assert extracted_tool_calls.tool_calls[0].function.name == "get_weather"

    expected_content = """<think>I want to get the weather.</think>

I will help you get the weather."""
    assert extracted_tool_calls.content == expected_content


def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser):
    """Test that malformed XML is handled gracefully."""
    model_output = """<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Seattle</arg_value>
<arg_key>incomplete_arg
<arg_value>value</arg_value>
</tool_call>"""

    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    # Should handle malformed XML gracefully
    # The parser should either extract what it can or return no tool calls
    # depending on how robust we want the parsing to be
    assert isinstance(extracted_tool_calls.tools_called, bool)
    assert isinstance(extracted_tool_calls.tool_calls, list)


def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser):
    """Test tool calls with no arguments."""
    model_output = """<tool_call>get_current_time
</tool_call>"""

    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    assert len(extracted_tool_calls.tool_calls) == 1
    assert extracted_tool_calls.tool_calls[0].function.name == "get_current_time"
    # Empty arguments should result in empty JSON object
    assert extracted_tool_calls.tool_calls[0].function.arguments == "{}"


def test_extract_tool_calls_mixed_content(glm4_moe_tool_parser):
    """Test extraction with mixed content and multiple tool calls."""
    model_output = """I will help you get the weather info.

<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>

meaningwhile, I will also check the weather in Shanghai.

<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Shanghai</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>"""

    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    assert len(extracted_tool_calls.tool_calls) == 2

    # Check first tool call
    assert extracted_tool_calls.tool_calls[0].function.name == "get_weather"
    args1 = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
    assert args1["city"] == "Beijing"
    assert args1["date"] == "2025-08-01"

    # Check second tool call
    assert extracted_tool_calls.tool_calls[1].function.name == "get_weather"
    args2 = json.loads(extracted_tool_calls.tool_calls[1].function.arguments)
    assert args2["city"] == "Shanghai"
    assert args2["date"] == "2025-08-01"

    # Content should be everything before the first tool call
    assert extracted_tool_calls.content == "I will help you get the weather info."


def test_streaming_basic_functionality(glm4_moe_tool_parser):
    """Test basic streaming functionality."""
    # Reset streaming state
    glm4_moe_tool_parser.current_tool_name_sent = False
    glm4_moe_tool_parser.prev_tool_call_arr = []
    glm4_moe_tool_parser.current_tool_id = -1
    glm4_moe_tool_parser.streamed_args_for_tool = []

    # Test with a simple tool call
    current_text = """<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
</tool_call>"""

    # Mock token IDs for testing
    tool_call_start_id = glm4_moe_tool_parser.tool_call_start_token_id or 12345
    tool_call_end_id = glm4_moe_tool_parser.tool_call_end_token_id or 12346

    result = glm4_moe_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=current_text,
        delta_text="</tool_call>",
        previous_token_ids=[],
        current_token_ids=[tool_call_start_id, tool_call_end_id],
        delta_token_ids=[tool_call_end_id],
        request=None,
    )

    # The result behavior depends on the streaming state
    # This test mainly ensures no exceptions are thrown
    assert result is None or hasattr(result, "tool_calls") or hasattr(result, "content")


def test_streaming_no_tool_calls(glm4_moe_tool_parser):
    """Test streaming when there are no tool calls."""
    current_text = "This is just regular text without any tool calls."

    result = glm4_moe_tool_parser.extract_tool_calls_streaming(
        previous_text="This is just regular text",
        current_text=current_text,
        delta_text=" without any tool calls.",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Should return the delta text as content
    assert result is not None
    assert hasattr(result, "content")
    assert result.content == " without any tool calls."


def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser):
    """Test streaming when there's content before tool calls."""
    # Reset streaming state
    glm4_moe_tool_parser.current_tool_name_sent = False
    glm4_moe_tool_parser.prev_tool_call_arr = []
    glm4_moe_tool_parser.current_tool_id = -1
    glm4_moe_tool_parser.streamed_args_for_tool = []

    current_text = "I will help you get the weather<tool_call>"

    result = glm4_moe_tool_parser.extract_tool_calls_streaming(
        previous_text="I will help you",
        current_text=current_text,
        delta_text="get the weather.<tool_call>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Should return content when no tool call tokens are detected
    assert result is not None
    assert hasattr(result, "content")
    assert result.content == "get the weather.<tool_call>"


def test_extract_tool_calls_special_characters(glm4_moe_tool_parser):
    """Test tool calls with special characters and unicode."""
    model_output = """<tool_call>send_message
<arg_key>recipient</arg_key>
<arg_value>Amy</arg_value>
<arg_key>message</arg_key>
<arg_value>It is a nice day</arg_value>
<arg_key>priority</arg_key>
<arg_value>high</arg_value>
</tool_call>"""

    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    assert len(extracted_tool_calls.tool_calls) == 1
    assert extracted_tool_calls.tool_calls[0].function.name == "send_message"

    args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
    assert args["recipient"] == "Amy"
    assert args["message"] == "It is a nice day"
    assert args["priority"] == "high"


def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
    """Test incomplete tool calls (missing closing tag)."""
    model_output = """<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>"""

    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    # Incomplete tool calls should not be extracted
    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content == model_output
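Aside: the arg_key/arg_value pairs these GLM-4.5 cases assert on map to JSON arguments roughly as in this simplified sketch. It is illustrative only, not the actual Glm4MoeModelToolParser logic.

import json
import re

# The first line of a <tool_call> body is the function name; the rest are key/value pairs.
ARG_RE = re.compile(r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL)

def sketch_glm4_arguments(tool_call_body: str) -> str:
    return json.dumps(dict(ARG_RE.findall(tool_call_body)))

body = """get_current_weather
<arg_key>city</arg_key>
<arg_value>Dallas</arg_value>
<arg_key>state</arg_key>
<arg_value>TX</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>"""
assert json.loads(sketch_glm4_arguments(body)) == {
    "city": "Dallas",
    "state": "TX",
    "unit": "fahrenheit",
}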
308  tests/tool_parsers/test_jamba_tool_parser.py  (new file)
@@ -0,0 +1,308 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from collections.abc import Generator

import partial_json_parser
import pytest
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tool_parsers.jamba_tool_parser import JambaToolParser

MODEL = "ai21labs/Jamba-tiny-dev"


@pytest.fixture(scope="module")
def jamba_tokenizer():
    return get_tokenizer(tokenizer_name=MODEL)


@pytest.fixture
def jamba_tool_parser(jamba_tokenizer):
    return JambaToolParser(jamba_tokenizer)


def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) > 16

        assert actual_tool_call.type == "function"
        assert actual_tool_call.function == expected_tool_call.function


def stream_delta_message_generator(
    jamba_tool_parser: JambaToolParser,
    jamba_tokenizer: TokenizerLike,
    model_output: str,
) -> Generator[DeltaMessage, None, None]:
    all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False)

    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[: i + 1]

        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
            detokenize_incrementally(
                tokenizer=jamba_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            )
        )

        current_text = previous_text + delta_text

        delta_message = jamba_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=None,  # type: ignore[arg-type]
        )
        if delta_message:
            yield delta_message

        previous_text = current_text
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset


def test_extract_tool_calls_no_tools(jamba_tool_parser):
    model_output = "This is a test"
    extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content == model_output


@pytest.mark.parametrize(
    ids=[
        "single_tool",
        "single_tool_with_content",
        "parallel_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """ Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            " Sure! let me call the tool for you.",
        ),
        (
            """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
                        ),
                    )
                ),
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls(
    jamba_tool_parser, model_output, expected_tool_calls, expected_content
):
    extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool",
        "single_tool_with_content",
        "parallel_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ("""This is a test""", [], """This is a test"""),
        (
            """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            " ",
        ),
        (
            """ Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                )
            ],
            " Sure! let me call the tool for you.",
        ),
        (
            """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
                        ),
                    )
                ),
            ],
            " ",
        ),
    ],
)
def test_extract_tool_calls_streaming(
    jamba_tool_parser,
    jamba_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    other_content: str = ""
    function_names: list[str] = []
    function_args_strs: list[str] = []
    tool_call_idx: int = -1
    tool_call_ids: list[str | None] = []

    for delta_message in stream_delta_message_generator(
        jamba_tool_parser, jamba_tokenizer, model_output
    ):
        # role should never be streamed from tool parser
        assert not delta_message.role

        if delta_message.content:
            other_content += delta_message.content

        streamed_tool_calls = delta_message.tool_calls

        if streamed_tool_calls and len(streamed_tool_calls) > 0:
            # make sure only one diff is present - correct even for parallel
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            # if a new tool is being called, set up empty arguments
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                function_args_strs.append("")
                tool_call_ids.append(None)

            # if a tool call ID is streamed, make sure one hasn't been already
            if tool_call.id and not tool_call_ids[tool_call.index]:
                tool_call_ids[tool_call.index] = tool_call.id

            # if parts of the function start being streamed
            if tool_call.function:
                # if the function name is defined, set it. it should be streamed
                # IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    function_names.append(tool_call.function.name)

                if tool_call.function.arguments:
                    # make sure they're a string and then add them to the list
                    assert isinstance(tool_call.function.arguments, str)

                    function_args_strs[tool_call.index] += tool_call.function.arguments

    assert other_content == expected_content

    actual_tool_calls = [
        ToolCall(
            id=tool_call_id,
            function=FunctionCall(
                name=function_name,
                arguments=partial_json_parser.ensure_json(
                    function_args_str, Allow.OBJ | Allow.STR
                ),
            ),
        )
        for tool_call_id, function_name, function_args_str in zip(
            tool_call_ids, function_names, function_args_strs
        )
    ]
    assert_tool_calls(actual_tool_calls, expected_tool_calls)
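Aside: the Jamba format above is a single JSON array wrapped in <tool_calls>...</tool_calls>, with any text before the opening tag treated as content. A minimal standalone sketch of that mapping, illustrative only and not the JambaToolParser implementation:

import json

def sketch_jamba_calls(model_output: str):
    # Everything before <tool_calls> is content; between the tags is a JSON array of calls.
    start = model_output.find("<tool_calls>")
    end = model_output.find("</tool_calls>")
    if start == -1 or end == -1:
        return model_output, []
    content = model_output[:start]
    calls = json.loads(model_output[start + len("<tool_calls>"):end])
    return content, calls

content, calls = sketch_jamba_calls(
    ' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather",'
    ' "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>'
)
assert content == " Sure! let me call the tool for you."
assert calls[0]["name"] == "get_current_weather"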
924  tests/tool_parsers/test_kimi_k2_tool_parser.py  (new file)
@@ -0,0 +1,924 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import json

import pytest

from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.kimi_k2_tool_parser import KimiK2ToolParser

# Use a common model that is likely to be available
MODEL = "moonshotai/Kimi-K2-Instruct"


@pytest.fixture(scope="module")
def kimi_k2_tokenizer():
    return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)


@pytest.fixture
def kimi_k2_tool_parser(kimi_k2_tokenizer):
    return KimiK2ToolParser(kimi_k2_tokenizer)


def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        assert actual_tool_call.type == "function"
        assert actual_tool_call.function == expected_tool_call.function

        # assert tool call id format: should contain function name and numeric index
        # Format can be either "functions.func_name:0" or "func_name:0"
        assert actual_tool_call.id.split(":")[-1].isdigit()
        assert (
            actual_tool_call.id.split(":")[0].split(".")[-1]
            == expected_tool_call.function.name
        )


def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
    model_output = "This is a test"
    extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content == model_output


@pytest.mark.parametrize(
    ids=[
        "tool_call_with_content_before",
        "multi_tool_call_with_content_before",
        "concatenated_tool_calls_bug_fix",
        "three_concatenated_tool_calls",
        "mixed_spacing_tool_calls",
        "angle_brackets_in_json",
        "newlines_in_json",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.get_weather:0",
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "city": "Beijing",
                            },
                        ),
                    ),
                    type="function",
                )
            ],
            "I'll help you check the weather. ",
        ),
        (
            """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.get_weather:0",
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "city": "Beijing",
                            },
                        ),
                    ),
                    type="function",
                ),
                ToolCall(
                    id="functions.get_weather:1",
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {
                                "city": "Shanghai",
                            },
                        ),
                    ),
                    type="function",
                ),
            ],
            "I'll help you check the weather. ",
        ),
        (
            """I'll get the weather and news for LA today. First, let me get the weather using Los Angeles coordinates, and then get the latest news. <|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"latitude": 34.0522, "longitude": -118.2437}<|tool_call_end|><|tool_call_begin|>functions.get_news:1<|tool_call_argument_begin|>{"content": "Los Angeles today"}<|tool_call_end|><|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.get_weather:0",
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps(
                            {"latitude": 34.0522, "longitude": -118.2437}
                        ),
                    ),
                    type="function",
                ),
                ToolCall(
                    id="functions.get_news:1",
                    function=FunctionCall(
                        name="get_news",
                        arguments=json.dumps({"content": "Los Angeles today"}),
                    ),
                    type="function",
                ),
            ],
            "I'll get the weather and news for LA today. First, let me get the weather using Los Angeles coordinates, and then get the latest news. ",
        ),
        (
            """I'll help you with multiple tasks. <|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "New York"}<|tool_call_end|><|tool_call_begin|>functions.get_news:1<|tool_call_argument_begin|>{"topic": "technology"}<|tool_call_end|><|tool_call_begin|>functions.send_email:2<|tool_call_argument_begin|>{"to": "user@example.com", "subject": "Daily Update"}<|tool_call_end|><|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.get_weather:0",
                    function=FunctionCall(
                        name="get_weather",
                        arguments=json.dumps({"city": "New York"}),
                    ),
                    type="function",
                ),
                ToolCall(
                    id="functions.get_news:1",
                    function=FunctionCall(
                        name="get_news",
                        arguments=json.dumps({"topic": "technology"}),
                    ),
                    type="function",
                ),
                ToolCall(
                    id="functions.send_email:2",
                    function=FunctionCall(
                        name="send_email",
                        arguments=json.dumps(
                            {"to": "user@example.com", "subject": "Daily Update"}
                        ),
                    ),
                    type="function",
                ),
            ],
            "I'll help you with multiple tasks. ",
        ),
        (
            """Mixed spacing test. <|tool_calls_section_begin|> <|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {} <|tool_call_end|><|tool_call_begin|>functions.test2:1<|tool_call_argument_begin|>{}<|tool_call_end|> <|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.test:0",
                    function=FunctionCall(
                        name="test",
                        arguments=json.dumps({}),
                    ),
                    type="function",
                ),
                ToolCall(
                    id="functions.test2:1",
                    function=FunctionCall(
                        name="test2",
                        arguments=json.dumps({}),
                    ),
                    type="function",
                ),
            ],
            "Mixed spacing test. ",
        ),
        (
            """I need to process HTML content. <|tool_calls_section_begin|><|tool_call_begin|>functions.process_html:0<|tool_call_argument_begin|>{"html": "<div>content</div>", "text": "normal text"}<|tool_call_end|><|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.process_html:0",
                    function=FunctionCall(
                        name="process_html",
                        arguments=json.dumps(
                            {"html": "<div>content</div>", "text": "normal text"}
                        ),
                    ),
                    type="function",
                )
            ],
            "I need to process HTML content. ",
        ),
        (
            """I need to process formatted JSON. <|tool_calls_section_begin|><|tool_call_begin|>functions.process_data:0<|tool_call_argument_begin|>{
  "name": "test",
  "value": 123,
  "nested": {
    "key": "value"
  }
}<|tool_call_end|><|tool_calls_section_end|>""",
            [
                ToolCall(
                    id="functions.process_data:0",
                    function=FunctionCall(
                        name="process_data",
                        arguments=json.dumps(
                            {"name": "test", "value": 123, "nested": {"key": "value"}},
                            indent=2,
                        ),
                    ),
                    type="function",
                )
            ],
            "I need to process formatted JSON. ",
        ),
    ],
)
def test_extract_tool_calls(
    kimi_k2_tool_parser, model_output, expected_tool_calls, expected_content
):
    extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


def test_extract_tool_calls_invalid_json(kimi_k2_tool_parser):
    """we'll return every funcall result"""
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing" <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""

    extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    # Should extract only the valid JSON tool calls
    assert len(extracted_tool_calls.tool_calls) == 2
    assert extracted_tool_calls.tool_calls[0].function.name == "invalid_get_weather"
    assert extracted_tool_calls.tool_calls[1].function.name == "valid_get_weather"


def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser):
    """we'll return every funcall result"""
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""

    extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    # Should extract only the valid JSON tool calls
    assert len(extracted_tool_calls.tool_calls) == 1
    assert extracted_tool_calls.tool_calls[0].function.name == "valid_get_weather"


def test_streaming_basic_functionality(kimi_k2_tool_parser):
    """Test basic streaming functionality."""
    # Reset streaming state
    kimi_k2_tool_parser.current_tool_name_sent = False
    kimi_k2_tool_parser.prev_tool_call_arr = []
    kimi_k2_tool_parser.current_tool_id = -1
    kimi_k2_tool_parser.streamed_args_for_tool = []

    # Test with a simple tool call
    current_text = """ check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>"""

    # First call should handle the initial setup
    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you",
        current_text=current_text,
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # The result might be None or contain tool call information
    # This depends on the internal state management
    if result is not None and hasattr(result, "tool_calls") and result.tool_calls:
        assert len(result.tool_calls) >= 0


def test_streaming_no_tool_calls(kimi_k2_tool_parser):
    """Test streaming when there are no tool calls."""
    current_text = "This is just regular text without any tool calls."

    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="This is just regular text",
        current_text=current_text,
        delta_text=" without any tool calls.",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Should return the delta text as content
    assert result is not None
    assert hasattr(result, "content")
    assert result.content == " without any tool calls."


def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
    """
    Test that text between <|tool_calls_section_begin|> and <|tool_call_begin|>
    is suppressed and does not leak into reasoning_delta.
    This is the main vulnerability being fixed.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    # Get token IDs for the markers
    section_begin_token_id = kimi_k2_tool_parser.vocab.get(
        "<|tool_calls_section_begin|>"
    )
    tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")

    # Simulate streaming sequence:
    # Delta 1: "I'll help you with that. "
    result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="I'll help you with that. ",
        delta_text="I'll help you with that. ",
        previous_token_ids=[],
        current_token_ids=[1, 2, 3],  # Regular tokens
        delta_token_ids=[1, 2, 3],
        request=None,
    )
    assert result1 is not None
    assert result1.content == "I'll help you with that. "

    # Delta 2: "<|tool_calls_section_begin|>"
    prev_ids = [1, 2, 3]
    curr_ids = prev_ids + [section_begin_token_id]
    result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you with that. ",
        current_text="I'll help you with that. <|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=prev_ids,
        current_token_ids=curr_ids,
        delta_token_ids=[section_begin_token_id],
        request=None,
    )
    # Section marker should be stripped and suppressed
    assert result2 is None or (result2.content is None or result2.content == "")

    # Delta 3: " spurious text or tokens " (THE LEAK SCENARIO)
    prev_ids = curr_ids
    curr_ids = curr_ids + [4, 5]
    result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you with that. <|tool_calls_section_begin|>",
        current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
        delta_text=" spurious text ",
        previous_token_ids=prev_ids,
        current_token_ids=curr_ids,
        delta_token_ids=[4, 5],
        request=None,
    )
    # CRITICAL: This text should be suppressed, NOT returned as reasoning_delta
    assert result3 is None or (result3.content is None or result3.content == "")

    # Delta 4: "<|tool_call_begin|>..."
    prev_ids = curr_ids
    curr_ids = curr_ids + [tool_call_begin_token_id]
    _result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
        current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text <|tool_call_begin|>",
        delta_text="<|tool_call_begin|>",
        previous_token_ids=prev_ids,
        current_token_ids=curr_ids,
        delta_token_ids=[tool_call_begin_token_id],
        request=None,
    )
    # Now we're in tool call mode, result depends on internal state
    # The key is that the spurious text from Delta 3 was not leaked


def test_split_markers_across_deltas(kimi_k2_tool_parser):
    """
    Test that markers split across delta chunks are correctly detected
    via the rolling buffer mechanism.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    section_begin_token_id = kimi_k2_tool_parser.vocab.get(
        "<|tool_calls_section_begin|>"
    )

    # Delta 1: "...reasoning<|tool_calls_sec"
    _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Some reasoning",
        current_text="Some reasoning<|tool_calls_sec",
        delta_text="<|tool_calls_sec",
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, 3],  # Partial token
        delta_token_ids=[3],
        request=None,
    )
    # Partial token not recognized yet, might be buffered
    # Should return as content or None (depends on implementation)

    # Delta 2: "tion_begin|> " (completes the marker)
    _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Some reasoning<|tool_calls_sec",
        current_text="Some reasoning<|tool_calls_section_begin|> ",
        delta_text="tion_begin|> ",
        previous_token_ids=[1, 2, 3],
        current_token_ids=[1, 2, section_begin_token_id, 4],
        delta_token_ids=[section_begin_token_id, 4],
        request=None,
    )
    # Now the complete marker should be detected via buffer
    # The parser should enter tool section mode
    assert kimi_k2_tool_parser.in_tool_section is True


def test_marker_variants(kimi_k2_tool_parser):
    """Test that both singular and plural marker variants are recognized."""
    kimi_k2_tool_parser.reset_streaming_state()

    # Test singular variant: <|tool_call_section_begin|> (note: singular "call")
    singular_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_section_begin|>")

    if singular_token_id is not None:  # Only test if tokenizer supports it
        _result = kimi_k2_tool_parser.extract_tool_calls_streaming(
            previous_text="Reasoning ",
            current_text="Reasoning <|tool_call_section_begin|>",
            delta_text="<|tool_call_section_begin|>",
            previous_token_ids=[1, 2],
            current_token_ids=[1, 2, singular_token_id],
            delta_token_ids=[singular_token_id],
            request=None,
        )
        # Should enter tool section mode with singular variant too
        assert kimi_k2_tool_parser.in_tool_section is True


def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
    """
    Test that after exiting a tool section with <|tool_calls_section_end|>,
    subsequent text is correctly returned as reasoning content.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")

    # Enter tool section
    _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="<|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[],
        current_token_ids=[section_begin_id],
        delta_token_ids=[section_begin_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is True

    # Exit tool section
    _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|>",
        current_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[section_begin_id],
        current_token_ids=[section_begin_id, section_end_id],
        delta_token_ids=[section_end_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is False

    # Subsequent reasoning text should be returned normally
    result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
        current_text="<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning",
        delta_text=" More reasoning",
        previous_token_ids=[section_begin_id, section_end_id],
        current_token_ids=[section_begin_id, section_end_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )
    assert result3 is not None
    assert result3.content == " More reasoning"


def test_empty_tool_section(kimi_k2_tool_parser):
    """Test an empty tool section (begin immediately followed by end)."""
    kimi_k2_tool_parser.reset_streaming_state()

    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")

    # Section begin
    _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning <|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[1],
        current_token_ids=[1, section_begin_id],
        delta_token_ids=[section_begin_id],
        request=None,
    )

    # Immediate section end
    _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning <|tool_calls_section_begin|>",
        current_text="Reasoning <|tool_calls_section_begin|><|tool_calls_section_end|>",
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[1, section_begin_id],
        current_token_ids=[1, section_begin_id, section_end_id],
        delta_token_ids=[section_end_id],
        request=None,
    )
    # Should exit cleanly without errors
    assert kimi_k2_tool_parser.in_tool_section is False


def test_malformed_tool_section_recovery(kimi_k2_tool_parser):
|
||||
"""
|
||||
Test that the parser recovers from a malformed tool section
|
||||
that never closes properly.
|
||||
"""
|
||||
kimi_k2_tool_parser.reset_streaming_state()
|
||||
|
||||
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
|
||||
|
||||
# Enter tool section
|
||||
_result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="<|tool_calls_section_begin|>",
|
||||
delta_text="<|tool_calls_section_begin|>",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[section_begin_id],
|
||||
delta_token_ids=[section_begin_id],
|
||||
request=None,
|
||||
)
|
||||
assert kimi_k2_tool_parser.in_tool_section is True
|
||||
|
||||
# Simulate a lot of text without proper tool calls or section end
|
||||
# This should trigger the error recovery mechanism
|
||||
large_text = "x" * 10000 # Exceeds max_section_chars
|
||||
|
||||
result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="<|tool_calls_section_begin|>",
|
||||
current_text="<|tool_calls_section_begin|>" + large_text,
|
||||
delta_text=large_text,
|
||||
previous_token_ids=[section_begin_id],
|
||||
current_token_ids=[section_begin_id] + list(range(100, 100 + len(large_text))),
|
||||
delta_token_ids=list(range(100, 100 + len(large_text))),
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Parser should have force-exited the tool section
|
||||
assert kimi_k2_tool_parser.in_tool_section is False
|
||||
# And returned the content as reasoning
|
||||
assert result2 is not None
|
||||
assert result2.content == large_text
|
||||
|
||||
|
||||
def test_state_reset(kimi_k2_tool_parser):
|
||||
"""Test that reset_streaming_state() properly clears all state."""
|
||||
# Put parser in a complex state
|
||||
kimi_k2_tool_parser.in_tool_section = True
|
||||
kimi_k2_tool_parser.token_buffer = "some buffer"
|
||||
kimi_k2_tool_parser.current_tool_id = 5
|
||||
kimi_k2_tool_parser.prev_tool_call_arr = [{"id": "test"}]
|
||||
kimi_k2_tool_parser.section_char_count = 1000
|
||||
|
||||
# Reset
|
||||
kimi_k2_tool_parser.reset_streaming_state()
|
||||
|
||||
# Verify all state is cleared
|
||||
assert kimi_k2_tool_parser.in_tool_section is False
|
||||
assert kimi_k2_tool_parser.token_buffer == ""
|
||||
assert kimi_k2_tool_parser.current_tool_id == -1
|
||||
assert kimi_k2_tool_parser.prev_tool_call_arr == []
|
||||
assert kimi_k2_tool_parser.section_char_count == 0
|
||||
assert kimi_k2_tool_parser.current_tool_name_sent is False
|
||||
assert kimi_k2_tool_parser.streamed_args_for_tool == []
|
||||
|
||||
|
||||
def test_section_begin_noise_tool_begin_same_chunk(kimi_k2_tool_parser):
|
||||
"""
|
||||
Test that begin→noise→tool_begin within the SAME chunk suppresses
|
||||
the noise text correctly (not just across chunks).
|
||||
"""
|
||||
kimi_k2_tool_parser.reset_streaming_state()
|
||||
|
||||
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
|
||||
tool_call_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
|
||||
|
||||
# Single delta containing: section_begin + spurious text + tool_call_begin
|
||||
combined_text = "<|tool_calls_section_begin|> noise text <|tool_call_begin|>"
|
||||
|
||||
result = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="Reasoning ",
|
||||
current_text="Reasoning " + combined_text,
|
||||
delta_text=combined_text,
|
||||
previous_token_ids=[1, 2],
|
||||
current_token_ids=[1, 2, section_begin_id, 3, 4, tool_call_begin_id],
|
||||
delta_token_ids=[section_begin_id, 3, 4, tool_call_begin_id],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# The noise text should NOT leak into content
|
||||
# Result should either be None/empty or start tool call parsing
|
||||
if result is not None and result.content is not None:
|
||||
# If content is returned, it should not contain the noise
|
||||
assert "noise text" not in result.content
|
||||
assert result.content == "" or result.content.strip() == ""
|
||||
|
||||
|
||||
def test_stream_ends_without_section_end_marker(kimi_k2_tool_parser):
|
||||
"""
|
||||
Test that if the stream ends (EOF) without a proper section end marker,
|
||||
the parser doesn't leak text, doesn't crash, and resets state cleanly.
|
||||
"""
|
||||
kimi_k2_tool_parser.reset_streaming_state()
|
||||
|
||||
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
|
||||
|
||||
# Enter tool section
|
||||
_result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="<|tool_calls_section_begin|>",
|
||||
delta_text="<|tool_calls_section_begin|>",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[section_begin_id],
|
||||
delta_token_ids=[section_begin_id],
|
||||
request=None,
|
||||
)
|
||||
assert kimi_k2_tool_parser.in_tool_section is True
|
||||
|
||||
# Some content in tool section
|
||||
result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="<|tool_calls_section_begin|>",
|
||||
current_text="<|tool_calls_section_begin|> partial content",
|
||||
delta_text=" partial content",
|
||||
previous_token_ids=[section_begin_id],
|
||||
current_token_ids=[section_begin_id, 10, 11],
|
||||
delta_token_ids=[10, 11],
|
||||
request=None,
|
||||
)
|
||||
# Content should be suppressed
|
||||
assert result2.content == "" or result2.content is None
|
||||
|
||||
# Stream ends (EOF) - no more deltas, no section_end marker
|
||||
# Simulate this by manually checking state and resetting
|
||||
# (In real usage, the request handler would call reset_streaming_state)
|
||||
assert kimi_k2_tool_parser.in_tool_section is True # Still in section
|
||||
|
||||
# Reset state (as would happen between requests)
|
||||
kimi_k2_tool_parser.reset_streaming_state()
|
||||
|
||||
# Verify clean slate
|
||||
assert kimi_k2_tool_parser.in_tool_section is False
|
||||
assert kimi_k2_tool_parser.token_buffer == ""
|
||||
|
||||
# Next request should work normally
|
||||
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="New reasoning",
|
||||
delta_text="New reasoning",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[20, 21],
|
||||
delta_token_ids=[20, 21],
|
||||
request=None,
|
||||
)
|
||||
assert result3 is not None
|
||||
assert result3.content == "New reasoning"
|
||||
|
||||
|
||||
def test_same_chunk_begin_and_end_markers(kimi_k2_tool_parser):
|
||||
"""
|
||||
CRITICAL TEST: Verify that when both section_begin and section_end
|
||||
markers appear in the SAME chunk, the parser correctly:
|
||||
1. Enters the tool section
|
||||
2. Immediately exits the tool section
|
||||
3. Does NOT get stuck in in_tool_section=True state
|
||||
|
||||
This tests the bug fix where elif was changed to if to handle
|
||||
both state transitions in a single delta.
|
||||
"""
|
||||
kimi_k2_tool_parser.reset_streaming_state()
|
||||
|
||||
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
|
||||
section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
|
||||
|
||||
# Single chunk with both markers (e.g., empty tool section)
|
||||
combined_delta = "<|tool_calls_section_begin|><|tool_calls_section_end|>"
|
||||
|
||||
result = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="Some reasoning ",
|
||||
current_text="Some reasoning " + combined_delta,
|
||||
delta_text=combined_delta,
|
||||
previous_token_ids=[1, 2],
|
||||
current_token_ids=[1, 2, section_begin_id, section_end_id],
|
||||
delta_token_ids=[section_begin_id, section_end_id],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# CRITICAL: Parser should NOT be stuck in tool section
|
||||
assert kimi_k2_tool_parser.in_tool_section is False, (
|
||||
"Parser stuck in tool section after processing both begin/end in same chunk. "
|
||||
"This indicates the elif bug was not fixed."
|
||||
)
|
||||
|
||||
# Result should be empty or contain only stripped content
|
||||
assert result is not None
|
||||
assert result.content == "" or result.content is None
|
||||
|
||||
# Verify subsequent content streams correctly (not suppressed)
|
||||
result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="Some reasoning " + combined_delta,
|
||||
current_text="Some reasoning " + combined_delta + " More reasoning",
|
||||
delta_text=" More reasoning",
|
||||
previous_token_ids=[1, 2, section_begin_id, section_end_id],
|
||||
current_token_ids=[1, 2, section_begin_id, section_end_id, 10, 11],
|
||||
delta_token_ids=[10, 11],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# This content should NOT be suppressed (we're out of tool section)
|
||||
assert result2 is not None
|
||||
assert result2.content == " More reasoning"
|
||||
|
||||
|
||||
def test_same_chunk_begin_content_end_markers(kimi_k2_tool_parser):
|
||||
"""
|
||||
Test the same-chunk scenario with actual content between markers.
|
||||
Example: <|tool_calls_section_begin|> text <|tool_calls_section_end|>
|
||||
all arriving in one delta. The key is that the state machine correctly
|
||||
transitions in and out within the same chunk.
|
||||
"""
|
||||
kimi_k2_tool_parser.reset_streaming_state()
|
||||
|
||||
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
|
||||
section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
|
||||
|
||||
# Chunk with begin, some whitespace/noise, and end all together
|
||||
# This simulates a tool section that opens and closes in the same chunk
|
||||
combined_delta = "<|tool_calls_section_begin|> <|tool_calls_section_end|>"
|
||||
|
||||
_result = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="Reasoning ",
|
||||
current_text="Reasoning " + combined_delta,
|
||||
delta_text=combined_delta,
|
||||
previous_token_ids=[1],
|
||||
current_token_ids=[1, section_begin_id, 100, section_end_id],
|
||||
delta_token_ids=[section_begin_id, 100, section_end_id],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Parser should exit cleanly (not stuck in tool section)
|
||||
assert kimi_k2_tool_parser.in_tool_section is False
|
||||
|
||||
# Verify the fix: next content should stream normally, not be suppressed
|
||||
result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="Reasoning " + combined_delta,
|
||||
current_text="Reasoning " + combined_delta + " Done",
|
||||
delta_text=" Done",
|
||||
previous_token_ids=[1, section_begin_id, 100, section_end_id],
|
||||
current_token_ids=[1, section_begin_id, 100, section_end_id, 200],
|
||||
delta_token_ids=[200],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Content after section should be returned (not suppressed)
|
||||
assert result2 is not None
|
||||
assert result2.content == " Done"
|
||||
|
||||
|
||||
def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
|
||||
"""
|
||||
CRITICAL TEST (P1): Verify that when both <|tool_call_end|> and
|
||||
<|tool_calls_section_end|> appear in the SAME chunk, the parser:
|
||||
1. Processes the tool_call_end first (emits final arguments)
|
||||
2. THEN exits the section
|
||||
3. Does NOT drop the final tool call update
|
||||
4. Does NOT leak special tokens into reasoning
|
||||
|
||||
This tests the deferred section exit fix.
|
||||
"""
|
||||
kimi_k2_tool_parser.reset_streaming_state()
|
||||
|
||||
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
|
||||
section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
|
||||
tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
|
||||
tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
|
||||
|
||||
# Simulate a streaming sequence for a SHORT tool call (all in one chunk):
|
||||
# 1. Reasoning text
|
||||
result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="Let me help. ",
|
||||
delta_text="Let me help. ",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[1, 2],
|
||||
delta_token_ids=[1, 2],
|
||||
request=None,
|
||||
)
|
||||
assert result1 is not None
|
||||
assert result1.content == "Let me help. "
|
||||
|
||||
# 2. Section begin
|
||||
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="Let me help. ",
|
||||
current_text="Let me help. <|tool_calls_section_begin|>",
|
||||
delta_text="<|tool_calls_section_begin|>",
|
||||
previous_token_ids=[1, 2],
|
||||
current_token_ids=[1, 2, section_begin_id],
|
||||
delta_token_ids=[section_begin_id],
|
||||
request=None,
|
||||
)
|
||||
assert kimi_k2_tool_parser.in_tool_section is True
|
||||
|
||||
# 3. Tool call begin + full content + tool_end + section_end ALL IN ONE CHUNK
|
||||
# This is the critical scenario for short tool calls
|
||||
combined = (
|
||||
'<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
|
||||
"<|tool_call_end|><|tool_calls_section_end|>"
|
||||
)
|
||||
|
||||
# Build up the previous text gradually to simulate realistic streaming
|
||||
prev_text = "Let me help. <|tool_calls_section_begin|>"
|
||||
curr_text = prev_text + combined
|
||||
|
||||
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text=prev_text,
|
||||
current_text=curr_text,
|
||||
delta_text=combined,
|
||||
previous_token_ids=[1, 2, section_begin_id],
|
||||
current_token_ids=[
|
||||
1,
|
||||
2,
|
||||
section_begin_id,
|
||||
tool_begin_id,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
tool_end_id,
|
||||
section_end_id,
|
||||
],
|
||||
delta_token_ids=[tool_begin_id, 10, 11, 12, tool_end_id, section_end_id],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# CRITICAL: Parser should have exited section AFTER processing tool
|
||||
assert kimi_k2_tool_parser.in_tool_section is False
|
||||
|
||||
# Tool call should have been emitted (not dropped)
|
||||
# The result might be the tool name or None depending on internal state, but
|
||||
# importantly, it should not return the literal tokens as content
|
||||
|
||||
if result3 is not None and result3.content is not None:
|
||||
# Verify no special tokens leaked into content
|
||||
assert "<|tool_call_end|>" not in result3.content
|
||||
assert "<|tool_calls_section_end|>" not in result3.content
|
||||
|
||||
# 4. Verify subsequent content streams normally
|
||||
result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text=curr_text,
|
||||
current_text=curr_text + " Done",
|
||||
delta_text=" Done",
|
||||
previous_token_ids=[
|
||||
1,
|
||||
2,
|
||||
section_begin_id,
|
||||
tool_begin_id,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
tool_end_id,
|
||||
section_end_id,
|
||||
],
|
||||
current_token_ids=[
|
||||
1,
|
||||
2,
|
||||
section_begin_id,
|
||||
tool_begin_id,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
tool_end_id,
|
||||
section_end_id,
|
||||
20,
|
||||
],
|
||||
delta_token_ids=[20],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Content after tool section should stream normally
|
||||
assert result4 is not None
|
||||
assert result4.content == " Done"
|
||||
1225
tests/tool_parsers/test_minimax_tool_parser.py
Normal file
1225
tests/tool_parsers/test_minimax_tool_parser.py
Normal file
File diff suppressed because it is too large
860
tests/tool_parsers/test_mistral_tool_parser.py
Normal file
860
tests/tool_parsers/test_mistral_tool_parser.py
Normal file
@@ -0,0 +1,860 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
import partial_json_parser
|
||||
import pytest
|
||||
from mistral_common.protocol.instruct.messages import AssistantMessage
|
||||
from mistral_common.protocol.instruct.request import InstructRequest
|
||||
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mistral_pre_v11_tokenizer():
|
||||
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
return get_tokenizer(tokenizer_name=MODEL)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mistral_tokenizer():
|
||||
MODEL = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
|
||||
return get_tokenizer(tokenizer_name=MODEL, tokenizer_mode="mistral")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mistral_pre_v11_tool_parser(mistral_pre_v11_tokenizer):
|
||||
return MistralToolParser(mistral_pre_v11_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mistral_tool_parser(mistral_tokenizer):
|
||||
return MistralToolParser(mistral_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
|
||||
expected_tool_calls: list[ToolCall],
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) == 9
|
||||
|
||||
if isinstance(actual_tool_call, ToolCall):
|
||||
assert actual_tool_call.type == "function"
|
||||
elif isinstance(actual_tool_call, DeltaToolCall):
|
||||
assert actual_tool_call.function is not None
|
||||
assert actual_tool_call.function.name is not None
|
||||
assert actual_tool_call.function.arguments is not None
|
||||
assert actual_tool_call.function is not None
|
||||
assert actual_tool_call.function.name == expected_tool_call.function.name, (
|
||||
f"got wrong function name:${actual_tool_call.function.name}"
|
||||
)
|
||||
assert (
|
||||
actual_tool_call.function.arguments == expected_tool_call.function.arguments
|
||||
), f"got wrong function argument:${actual_tool_call.function.arguments}"
|
||||
|
||||
|
||||
def fix_tool_call_tokenization(
|
||||
tokens: list[int],
|
||||
mistral_tool_parser: MistralToolParser,
|
||||
mistral_tokenizer: TokenizerLike,
|
||||
):
|
||||
"""
|
||||
Replaces the textual token sequence for [TOOL_CALLS]
|
||||
with its single special token ID.
|
||||
"""
|
||||
textual_tool_call_token_ids = mistral_tokenizer.encode(
|
||||
text=mistral_tool_parser.bot_token,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
# textual_tool_call_token_ids must not contain special tokens like BOS, EOS, etc.
|
||||
special_tool_call_token_ids = [mistral_tool_parser.bot_token_id]
|
||||
|
||||
# If the input is too short to contain the sequence, no replacement is possible
|
||||
if not tokens or len(tokens) < len(textual_tool_call_token_ids):
|
||||
return tokens
|
||||
|
||||
result_tokens = []
|
||||
i = 0
|
||||
target_len = len(textual_tool_call_token_ids)
|
||||
|
||||
while i < len(tokens):
|
||||
# Check if the slice from the current position matches the target sequence
|
||||
if tokens[i : i + target_len] == textual_tool_call_token_ids:
|
||||
# If it matches, add the replacement and jump the index forward
|
||||
result_tokens.extend(special_tool_call_token_ids)
|
||||
i += target_len
|
||||
else:
|
||||
# Otherwise, just add the current token and move to the next one
|
||||
result_tokens.append(tokens[i])
|
||||
i += 1
|
||||
|
||||
return result_tokens
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
mistral_tool_parser: MistralToolParser,
|
||||
mistral_tokenizer: TokenizerLike,
|
||||
model_output: str | None,
|
||||
tools: list[tuple[str, str]] | None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
if (
|
||||
isinstance(mistral_tokenizer, MistralTokenizer)
|
||||
and mistral_tokenizer.version >= 11
|
||||
):
|
||||
# Newer versions of the tokenizer
|
||||
# cannot tokenize free text directly,
|
||||
# so we need to build a list of messages to tokenize instead
|
||||
assert tools is not None
|
||||
assistant_msg = AssistantMessage(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name=name,
|
||||
arguments=arg,
|
||||
)
|
||||
)
|
||||
for (name, arg) in tools
|
||||
],
|
||||
)
|
||||
request = InstructRequest(
|
||||
messages=[assistant_msg],
|
||||
)
|
||||
all_token_ids = mistral_tokenizer.instruct.encode_instruct(request).tokens
|
||||
else:
|
||||
# Older versions of the tokenizer are
|
||||
# able to encode the model's output (free text) directly into tokens
|
||||
assert model_output is not None
|
||||
all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)
|
||||
|
||||
all_token_ids = fix_tool_call_tokenization(
|
||||
all_token_ids, mistral_tool_parser, mistral_tokenizer
|
||||
)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
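# Replay the model output one token at a time so the parser sees realistic streaming deltas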
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[: i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
|
||||
detokenize_incrementally(
|
||||
tokenizer=mistral_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=isinstance(mistral_tokenizer, MistralTokenizer),
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message = mistral_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
request=None, # type: ignore[arg-type]
|
||||
)
|
||||
if delta_message:
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (
|
||||
previous_tokens + new_tokens if previous_tokens else new_tokens
|
||||
)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls(
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"single_tool_add",
|
||||
"single_tool_weather",
|
||||
"argument_before_name",
|
||||
"argument_before_name_and_name_in_argument",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_age",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"name": "John Doe",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_pre_v11_tokenizer(
|
||||
mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls(
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"single_tool_add",
|
||||
"single_tool_weather",
|
||||
"multiple_tool_calls",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add_this_and_that",
|
||||
arguments=json.dumps({"a": 3.5, "b": 4}),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="multiply", arguments=json.dumps({"a": 3, "b": 6})
|
||||
)
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(
|
||||
mistral_tool_parser, model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
extracted_tool_calls = mistral_tool_parser.extract_tool_calls(
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def _test_extract_tool_calls_streaming(
|
||||
tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content
|
||||
):
|
||||
other_content: str = ""
|
||||
function_names: list[str] = []
|
||||
function_args_strs: list[str] = []
|
||||
tool_call_idx: int = -1
|
||||
tool_call_ids: list[str | None] = []
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
tool_parser, tokenizer, model_output, tools
|
||||
):
|
||||
# role should never be streamed from tool parser
|
||||
assert not delta_message.role
|
||||
|
||||
if delta_message.content:
|
||||
other_content += delta_message.content
|
||||
|
||||
streamed_tool_calls = delta_message.tool_calls
|
||||
|
||||
if streamed_tool_calls and len(streamed_tool_calls) > 0:
|
||||
# make sure only one diff is present - this holds even for parallel tool calls
|
||||
assert len(streamed_tool_calls) == 1
|
||||
tool_call = streamed_tool_calls[0]
|
||||
|
||||
assert len(tool_parser.prev_tool_call_arr) > 0
|
||||
|
||||
# if a new tool is being called, set up empty arguments
|
||||
if tool_call.index != tool_call_idx:
|
||||
tool_call_idx = tool_call.index
|
||||
function_args_strs.append("")
|
||||
tool_call_ids.append(None)
|
||||
|
||||
# if a tool call ID is streamed, make sure one hasn't already been recorded
|
||||
if tool_call.id and not tool_call_ids[tool_call.index]:
|
||||
tool_call_ids[tool_call.index] = tool_call.id
|
||||
|
||||
# if parts of the function start being streamed
|
||||
if tool_call.function:
|
||||
# if the function name is defined, record it. it should be streamed
|
||||
# IN ENTIRETY, exactly one time.
|
||||
if tool_call.function.name:
|
||||
assert isinstance(tool_call.function.name, str)
|
||||
function_names.append(tool_call.function.name)
|
||||
|
||||
if tool_call.function.arguments:
|
||||
# make sure they're a string and then add them to the list
|
||||
assert isinstance(tool_call.function.arguments, str)
|
||||
|
||||
function_args_strs[tool_call.index] += tool_call.function.arguments
|
||||
|
||||
assert other_content == expected_content
|
||||
|
||||
actual_tool_calls = [
|
||||
ToolCall(
|
||||
id=tool_call_id,
|
||||
function=FunctionCall(
|
||||
name=function_name,
|
||||
arguments=partial_json_parser.ensure_json(
|
||||
function_args_str, Allow.OBJ | Allow.STR
|
||||
),
|
||||
),
|
||||
)
|
||||
for tool_call_id, function_name, function_args_str in zip(
|
||||
tool_call_ids, function_names, function_args_strs
|
||||
)
|
||||
]
|
||||
assert_tool_calls(actual_tool_calls, expected_tool_calls)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"no_tools",
|
||||
"single_tool_add",
|
||||
"single_tool_add_strings",
|
||||
"single_tool_weather",
|
||||
"argument_before_name",
|
||||
"argument_before_name_and_name_in_argument",
|
||||
"multiple_tools",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
("""This is a test""", [], """This is a test"""),
|
||||
(
|
||||
"""[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add", arguments=json.dumps({"a": 3, "b": 4})
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add", arguments=json.dumps({"a": "3", "b": "4"})
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_age",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"name": "John Doe",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
"",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming_pre_v11_tokenizer(
|
||||
mistral_pre_v11_tool_parser,
|
||||
mistral_pre_v11_tokenizer,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
_test_extract_tool_calls_streaming(
|
||||
mistral_pre_v11_tool_parser,
|
||||
mistral_pre_v11_tokenizer,
|
||||
model_output,
|
||||
None,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"single_tool_add",
|
||||
"single_tool_add_strings",
|
||||
"multiple_tools",
|
||||
],
|
||||
argnames=["tools", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
[("add", '{"a": 3, "b": 4}')],
|
||||
# [TOOL_CALLS]add{"a": 3, "b": 4}
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add", arguments=json.dumps({"a": 3, "b": 4})
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
[("add_two_strings", '{"a": "3", "b": "4"}')],
|
||||
# [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"}
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add_two_strings",
|
||||
arguments=json.dumps({"a": "3", "b": "4"}),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
[
|
||||
("add", '{"a": 3.5, "b": 4}'),
|
||||
(
|
||||
"get_current_weather",
|
||||
'{"city": "San Francisco", "state": "CA", "unit": "celsius"}', # noqa: E501
|
||||
),
|
||||
],
|
||||
# [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"} # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
"",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming(
|
||||
mistral_tool_parser,
|
||||
mistral_tokenizer,
|
||||
tools,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
_test_extract_tool_calls_streaming(
|
||||
mistral_tool_parser,
|
||||
mistral_tokenizer,
|
||||
None,
|
||||
tools,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"single_tool_add",
|
||||
"single_tool_weather",
|
||||
"multiple_tool_calls",
|
||||
"content_before_tool",
|
||||
"complex",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add_this_and_that",
|
||||
arguments=json.dumps({"a": 3.5, "b": 4}),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="multiply", arguments=json.dumps({"a": 3, "b": 6})
|
||||
)
|
||||
),
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
# Content may precede the tool calls, but no additional content should follow them
|
||||
"""bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add_this_and_that",
|
||||
arguments=json.dumps({"a": 3.5, "b": 4}),
|
||||
)
|
||||
)
|
||||
],
|
||||
"bla",
|
||||
),
|
||||
(
|
||||
# Complex
|
||||
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="bash",
|
||||
arguments=json.dumps(
|
||||
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming_one_chunk(
|
||||
mistral_tool_parser,
|
||||
mistral_tokenizer,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
if isinstance(mistral_tokenizer, MistralTokenizer):
|
||||
all_token_ids = mistral_tokenizer.encode(model_output)
|
||||
else:
|
||||
all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)
|
||||
all_token_ids = fix_tool_call_tokenization(
|
||||
all_token_ids, mistral_tool_parser, mistral_tokenizer
|
||||
)
|
||||
|
||||
delta_message = mistral_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text=model_output,
|
||||
delta_text=model_output,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=all_token_ids,
|
||||
delta_token_ids=all_token_ids,
|
||||
request=None,
|
||||
) # type: ignore[arg-type]
|
||||
assert isinstance(delta_message, DeltaMessage)
|
||||
assert len(delta_message.tool_calls) == len(expected_tool_calls)
|
||||
|
||||
assert_tool_calls(delta_message.tool_calls, expected_tool_calls)
|
||||
|
||||
if delta_message.content is None:
|
||||
assert expected_content == ""
|
||||
else:
|
||||
assert delta_message.content == expected_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"no_tools",
|
||||
"single_tool_add",
|
||||
"single_tool_add_strings",
|
||||
"single_tool_weather",
|
||||
"argument_before_name",
|
||||
"argument_before_name_and_name_in_argument",
|
||||
"multiple_tools",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
("""This is a test""", [], """This is a test"""),
|
||||
(
|
||||
"""[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add", arguments=json.dumps({"a": 3, "b": 4})
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add", arguments=json.dumps({"a": "3", "b": "4"})
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_age",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"name": "John Doe",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
"",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk(
|
||||
mistral_pre_v11_tool_parser,
|
||||
mistral_pre_v11_tokenizer,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer):
|
||||
all_token_ids = mistral_pre_v11_tokenizer.encode(model_output)
|
||||
else:
|
||||
all_token_ids = mistral_pre_v11_tokenizer.encode(
|
||||
model_output, add_special_tokens=False
|
||||
)
|
||||
all_token_ids = fix_tool_call_tokenization(
|
||||
all_token_ids, mistral_pre_v11_tool_parser, mistral_pre_v11_tokenizer
|
||||
)
|
||||
|
||||
delta_message = mistral_pre_v11_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text=model_output,
|
||||
delta_text=model_output,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=all_token_ids,
|
||||
delta_token_ids=all_token_ids,
|
||||
request=None,
|
||||
) # type: ignore[arg-type]
|
||||
assert isinstance(delta_message, DeltaMessage)
|
||||
assert len(delta_message.tool_calls) == len(expected_tool_calls)
|
||||
|
||||
assert_tool_calls(delta_message.tool_calls, expected_tool_calls)
|
||||
|
||||
if delta_message.content is None:
|
||||
assert expected_content == ""
|
||||
else:
|
||||
assert delta_message.content == expected_content
|
||||
263
tests/tool_parsers/test_openai_tool_parser.py
Normal file
263
tests/tool_parsers/test_openai_tool_parser.py
Normal file
@@ -0,0 +1,263 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from openai_harmony import (
|
||||
Conversation,
|
||||
DeveloperContent,
|
||||
HarmonyEncodingName,
|
||||
Message,
|
||||
Role,
|
||||
SystemContent,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.tool_parsers.openai_tool_parser import OpenAIToolParser
|
||||
|
||||
MODEL = "gpt2"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def openai_tokenizer():
|
||||
# The parser does not use the tokenizer, but the constructor requires it.
|
||||
return get_tokenizer(MODEL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_tool_parser(openai_tokenizer):
|
||||
return OpenAIToolParser(openai_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def harmony_encoding():
|
||||
return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall],
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) > 16 # Default from protocol.py
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
|
||||
convo = Conversation.from_messages(
|
||||
[
|
||||
Message.from_role_and_content(
|
||||
Role.SYSTEM,
|
||||
SystemContent.new(),
|
||||
),
|
||||
Message.from_role_and_content(
|
||||
Role.DEVELOPER,
|
||||
DeveloperContent.new().with_instructions("Talk like a pirate!"),
|
||||
),
|
||||
Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "This is a test"
|
||||
).with_channel("final"),
|
||||
]
|
||||
)
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo, Role.ASSISTANT
|
||||
)
|
||||
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||
"",
|
||||
request=None,
|
||||
token_ids=token_ids,
|
||||
)
|
||||
assert not extracted_info.tools_called
|
||||
assert extracted_info.tool_calls == []
|
||||
assert extracted_info.content == "This is a test"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool_args",
|
||||
[
|
||||
'{"location": "Tokyo"}',
|
||||
'{\n"location": "Tokyo"\n}',
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_single_tool(
|
||||
openai_tool_parser, harmony_encoding, tool_args
|
||||
):
|
||||
convo = Conversation.from_messages(
|
||||
[
|
||||
Message.from_role_and_content(Role.USER, "What is the weather in Tokyo?"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, tool_args)
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.get_current_weather")
|
||||
.with_content_type("json"),
|
||||
]
|
||||
)
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo, Role.ASSISTANT
|
||||
)
|
||||
|
||||
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||
"",
|
||||
request=None,
|
||||
token_ids=token_ids,
|
||||
)
|
||||
assert extracted_info.tools_called
|
||||
expected_tool_calls = [
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)
|
||||
)
|
||||
]
|
||||
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||
assert extracted_info.content is None
|
||||
|
||||
|
||||
def test_extract_tool_calls_multiple_tools(
|
||||
openai_tool_parser,
|
||||
harmony_encoding,
|
||||
):
|
||||
convo = Conversation.from_messages(
|
||||
[
|
||||
Message.from_role_and_content(
|
||||
Role.USER, "What is the weather in Tokyo based on where I'm at?"
|
||||
),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.get_current_weather")
|
||||
.with_content_type("json"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.get_user_location")
|
||||
.with_content_type("json"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.no_content_type"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, "foo")
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.not_json_no_content_type"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, "{}")
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.empty_args")
|
||||
.with_content_type("json"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, "")
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.no_args")
|
||||
.with_content_type("json"),
|
||||
]
|
||||
)
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo,
|
||||
Role.ASSISTANT,
|
||||
)
|
||||
|
||||
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||
"",
|
||||
request=None,
|
||||
token_ids=token_ids,
|
||||
)
|
||||
assert extracted_info.tools_called
|
||||
expected_tool_calls = [
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_user_location",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="no_content_type",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="not_json_no_content_type",
|
||||
arguments="foo",
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="empty_args",
|
||||
arguments=json.dumps({}),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="no_args",
|
||||
arguments="",
|
||||
)
|
||||
),
|
||||
]
|
||||
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||
assert extracted_info.content is None
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_content(
|
||||
openai_tool_parser,
|
||||
harmony_encoding,
|
||||
):
|
||||
final_content = "This tool call will get the weather."
|
||||
convo = Conversation.from_messages(
|
||||
[
|
||||
Message.from_role_and_content(
|
||||
Role.USER, "What is the weather in Tokyo based on where I'm at?"
|
||||
),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.get_current_weather")
|
||||
.with_content_type("json"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, final_content).with_channel(
|
||||
"final"
|
||||
),
|
||||
]
|
||||
)
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo,
|
||||
Role.ASSISTANT,
|
||||
)
|
||||
|
||||
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||
"",
|
||||
request=None,
|
||||
token_ids=token_ids,
|
||||
)
|
||||
assert extracted_info.tools_called
|
||||
expected_tool_calls = [
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)
|
||||
),
|
||||
]
|
||||
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||
assert extracted_info.content == final_content
|
||||
976
tests/tool_parsers/test_qwen3coder_tool_parser.py
Normal file
976
tests/tool_parsers/test_qwen3coder_tool_parser.py
Normal file
@@ -0,0 +1,976 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaMessage,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.tool_parsers.qwen3coder_tool_parser import (
|
||||
Qwen3CoderToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser
|
||||
|
||||
MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def qwen3_tokenizer():
|
||||
return get_tokenizer(tokenizer_name=MODEL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qwen3_tool_parser(qwen3_tokenizer):
|
||||
return Qwen3CoderToolParser(qwen3_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qwen3_xml_tool_parser(qwen3_tokenizer):
|
||||
return Qwen3XMLToolParser(qwen3_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture(params=["xml"])
|
||||
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request):
|
||||
"""Parameterized fixture that provides both parser types for testing"""
|
||||
if request.param == "original":
|
||||
return qwen3_tool_parser
|
||||
else:
|
||||
return qwen3_xml_tool_parser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools():
|
||||
return [
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city name"},
|
||||
"state": {"type": "string", "description": "The state code"},
|
||||
"unit": {"type": "string", "enum": ["fahrenheit", "celsius"]},
|
||||
},
|
||||
"required": ["city", "state"],
|
||||
},
|
||||
},
|
||||
),
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "calculate_area",
|
||||
"description": "Calculate area of a shape",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"shape": {"type": "string"},
|
||||
"dimensions": {"type": "object"},
|
||||
"precision": {"type": "integer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
# Qwen3 parser doesn't generate IDs during extraction
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function.name == expected_tool_call.function.name
|
||||
assert json.loads(actual_tool_call.function.arguments) == json.loads(
|
||||
expected_tool_call.function.arguments
|
||||
)
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
qwen3_tool_parser,
|
||||
qwen3_tokenizer: TokenizerLike,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest | None = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
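    """Replay ``model_output`` token by token, detokenizing incrementally and
    feeding each text delta to the parser's streaming entry point, yielding
    every DeltaMessage the parser emits."""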
|
||||
all_token_ids = qwen3_tokenizer.encode(model_output, add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[: i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
|
||||
detokenize_incrementally(
|
||||
tokenizer=qwen3_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message = qwen3_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
request=request,
|
||||
)
|
||||
if delta_message:
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (
|
||||
previous_tokens + new_tokens if previous_tokens else new_tokens
|
||||
)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized):
|
||||
model_output = "This is a test response without any tool calls"
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"single_tool",
|
||||
"single_tool_with_content",
|
||||
"single_tool_multiline_param",
|
||||
"parallel_tools",
|
||||
"tool_with_typed_params",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""Sure! Let me check the weather for you.<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"Sure! Let me check the weather for you.",
|
||||
),
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=calculate_area>
|
||||
<parameter=shape>
|
||||
rectangle
|
||||
</parameter>
|
||||
<parameter=dimensions>
|
||||
{"width": 10,
|
||||
"height": 20}
|
||||
</parameter>
|
||||
<parameter=precision>
|
||||
2
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="calculate_area",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"shape": "rectangle",
|
||||
"dimensions": {"width": 10, "height": 20},
|
||||
"precision": 2,
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Orlando
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
FL
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""Let me calculate that area for you.<tool_call>
|
||||
<function=calculate_area>
|
||||
<parameter=shape>
|
||||
circle
|
||||
</parameter>
|
||||
<parameter=dimensions>
|
||||
{"radius": 15.5}
|
||||
</parameter>
|
||||
<parameter=precision>
|
||||
3
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="calculate_area",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"shape": "circle",
|
||||
"dimensions": {"radius": 15.5},
|
||||
"precision": 3,
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"Let me calculate that area for you.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(
|
||||
qwen3_tool_parser_parametrized,
|
||||
sample_tools,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request
|
||||
)
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def test_extract_tool_calls_fallback_no_tags(
|
||||
qwen3_tool_parser_parametrized, sample_tools
|
||||
):
|
||||
"""Test fallback parsing when XML tags are missing"""
|
||||
model_output = """<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
</function>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request
|
||||
)
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather"
|
||||
|
||||
|
||||
def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized):
|
||||
"""Test parameter type conversion based on tool schema"""
|
||||
tools = [
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "test_types",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_param": {"type": "integer"},
|
||||
"float_param": {"type": "float"},
|
||||
"bool_param": {"type": "boolean"},
|
||||
"str_param": {"type": "string"},
|
||||
"obj_param": {"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
model_output = """<tool_call>
|
||||
<function=test_types>
|
||||
<parameter=int_param>
|
||||
42
|
||||
</parameter>
|
||||
<parameter=float_param>
|
||||
3.14
|
||||
</parameter>
|
||||
<parameter=bool_param>
|
||||
true
|
||||
</parameter>
|
||||
<parameter=str_param>
|
||||
hello world
|
||||
</parameter>
|
||||
<parameter=obj_param>
|
||||
{"key": "value"}
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request
|
||||
)
|
||||
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
assert args["int_param"] == 42
|
||||
assert args["float_param"] == 3.14
|
||||
assert args["bool_param"] is True
|
||||
assert args["str_param"] == "hello world"
|
||||
assert args["obj_param"] == {"key": "value"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"no_tools",
|
||||
"single_tool",
|
||||
"single_tool_with_content",
|
||||
"single_tool_multiline_param",
|
||||
"parallel_tools",
|
||||
"tool_with_typed_params", # Added this test case
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
("This is a test without tools", [], "This is a test without tools"),
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""Sure! Let me check the weather for you.<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"Sure! Let me check the weather for you.",
|
||||
),
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=calculate_area>
|
||||
<parameter=shape>
|
||||
rectangle
|
||||
</parameter>
|
||||
<parameter=dimensions>
|
||||
{"width": 10,
|
||||
"height": 20}
|
||||
</parameter>
|
||||
<parameter=precision>
|
||||
2
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="calculate_area",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"shape": "rectangle",
|
||||
"dimensions": {"width": 10, "height": 20},
|
||||
"precision": 2,
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Orlando
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
FL
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
celsius
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Orlando", "state": "FL", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
# Added tool_with_typed_params test case
|
||||
(
|
||||
"""Let me calculate that area for you.<tool_call>
|
||||
<function=calculate_area>
|
||||
<parameter=shape>
|
||||
circle
|
||||
</parameter>
|
||||
<parameter=dimensions>
|
||||
{"radius": 15.5}
|
||||
</parameter>
|
||||
<parameter=precision>
|
||||
3
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="calculate_area",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"shape": "circle",
|
||||
"dimensions": {"radius": 15.5},
|
||||
"precision": 3,
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"Let me calculate that area for you.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming(
|
||||
qwen3_tool_parser_parametrized,
|
||||
qwen3_tokenizer,
|
||||
sample_tools,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
"""Test incremental streaming behavior including typed parameters"""
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
other_content = ""
|
||||
tool_states = {} # Track state per tool index
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
|
||||
):
|
||||
# role should never be streamed from tool parser
|
||||
assert not delta_message.role
|
||||
|
||||
if delta_message.content:
|
||||
other_content += delta_message.content
|
||||
|
||||
if delta_message.tool_calls:
|
||||
for tool_call in delta_message.tool_calls:
|
||||
idx = tool_call.index
|
||||
|
||||
# Initialize state for new tool
|
||||
if idx not in tool_states:
|
||||
tool_states[idx] = {
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
"type": None,
|
||||
}
|
||||
|
||||
# First chunk should have id, name, and type
|
||||
if tool_call.id:
|
||||
tool_states[idx]["id"] = tool_call.id
|
||||
|
||||
if tool_call.type:
|
||||
assert tool_call.type == "function"
|
||||
tool_states[idx]["type"] = tool_call.type
|
||||
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
# Should only be set once
|
||||
assert tool_states[idx]["name"] is None
|
||||
tool_states[idx]["name"] = tool_call.function.name
|
||||
|
||||
if tool_call.function.arguments is not None:
|
||||
# Accumulate arguments incrementally
|
||||
tool_states[idx]["arguments"] += tool_call.function.arguments
|
||||
|
||||
# Verify final content
|
||||
assert other_content == (expected_content or "") # Handle None case
|
||||
|
||||
# Verify we got all expected tool calls
|
||||
assert len(tool_states) == len(expected_tool_calls)
|
||||
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == len(
|
||||
expected_tool_calls
|
||||
)
|
||||
|
||||
# Verify each tool call
|
||||
for idx, expected_tool in enumerate(expected_tool_calls):
|
||||
state = tool_states[idx]
|
||||
assert state["id"] is not None
|
||||
assert state["type"] == "function"
|
||||
assert state["name"] == expected_tool.function.name
|
||||
|
||||
# Parse accumulated arguments
|
||||
arguments_str = state["arguments"]
|
||||
assert arguments_str is not None
|
||||
actual_args = json.loads(arguments_str)
|
||||
expected_args = json.loads(expected_tool.function.arguments)
|
||||
assert actual_args == expected_args
|
||||
|
||||
|
||||
def test_extract_tool_calls_missing_closing_parameter_tag(
|
||||
qwen3_tool_parser_parametrized, sample_tools
|
||||
):
|
||||
"""Test handling of missing closing </parameter> tag"""
|
||||
# Using get_current_weather from sample_tools but with malformed XML
|
||||
model_output = """Let me check the weather for you:
|
||||
<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request
|
||||
)
|
||||
|
||||
# The parser should handle the malformed XML gracefully
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
|
||||
# Verify the function name is correct
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather"
|
||||
|
||||
# Verify the arguments are parsed despite the missing closing tag
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
assert "city" in args
|
||||
assert args["city"] == "Dallas"
|
||||
assert args["state"] == "TX"
|
||||
assert args["unit"] == "fahrenheit"
|
||||
|
||||
# Check that content before the tool call is preserved
|
||||
assert "Let me check the weather for you:" in extracted_tool_calls.content
|
||||
|
||||
|
||||
def test_extract_tool_calls_streaming_missing_closing_tag(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
|
||||
):
|
||||
"""Test streaming with missing closing </parameter> tag"""
|
||||
# Using get_current_weather from sample_tools but with malformed XML
|
||||
model_output = """Let me check the weather for you:
|
||||
<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
other_content = ""
|
||||
tool_states = {}
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
|
||||
):
|
||||
if delta_message.content:
|
||||
other_content += delta_message.content
|
||||
|
||||
if delta_message.tool_calls:
|
||||
for tool_call in delta_message.tool_calls:
|
||||
idx = tool_call.index
|
||||
|
||||
if idx not in tool_states:
|
||||
tool_states[idx] = {
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
"type": None,
|
||||
}
|
||||
|
||||
if tool_call.id:
|
||||
tool_states[idx]["id"] = tool_call.id
|
||||
|
||||
if tool_call.type:
|
||||
assert tool_call.type == "function"
|
||||
tool_states[idx]["type"] = tool_call.type
|
||||
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
tool_states[idx]["name"] = tool_call.function.name
|
||||
|
||||
if tool_call.function.arguments is not None:
|
||||
tool_states[idx]["arguments"] += tool_call.function.arguments
|
||||
|
||||
# Verify content was streamed
|
||||
assert "Let me check the weather for you:" in other_content
|
||||
# Verify we got the tool call
|
||||
assert len(tool_states) == 1
|
||||
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1
|
||||
|
||||
state = tool_states[0]
|
||||
assert state["id"] is not None
|
||||
assert state["type"] == "function"
|
||||
assert state["name"] == "get_current_weather"
|
||||
|
||||
# Verify arguments were parsed correctly despite missing closing tag
|
||||
assert state["arguments"] is not None
|
||||
args = json.loads(state["arguments"])
|
||||
assert args["city"] == "Dallas"
|
||||
assert args["state"] == "TX"
|
||||
assert args["unit"] == "fahrenheit"
|
||||
|
||||
|
||||
def test_extract_tool_calls_streaming_incremental(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
|
||||
):
|
||||
"""Test that streaming is truly incremental"""
|
||||
model_output = """I'll check the weather.<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
chunks = []
|
||||
for delta_message in stream_delta_message_generator(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
|
||||
):
|
||||
chunks.append(delta_message)
|
||||
|
||||
# Should have multiple chunks
|
||||
assert len(chunks) > 3
|
||||
|
||||
# First chunk(s) should be content
|
||||
assert chunks[0].content is not None
|
||||
assert chunks[0].tool_calls is None or chunks[0].tool_calls == []
|
||||
|
||||
# Should have a chunk with tool header (id, name, type)
|
||||
header_found = False
|
||||
for chunk in chunks:
|
||||
if chunk.tool_calls and chunk.tool_calls[0].id:
|
||||
header_found = True
|
||||
assert chunk.tool_calls[0].function.name == "get_current_weather"
|
||||
assert chunk.tool_calls[0].type == "function"
|
||||
# Empty initially
|
||||
assert chunk.tool_calls[0].function.arguments == ""
|
||||
break
|
||||
assert header_found
|
||||
|
||||
# Should have chunks with incremental arguments
|
||||
arg_chunks = []
|
||||
for chunk in chunks:
|
||||
if chunk.tool_calls and chunk.tool_calls[0].function.arguments:
|
||||
arg_chunks.append(chunk.tool_calls[0].function.arguments)
|
||||
|
||||
# Arguments should be streamed incrementally
|
||||
assert len(arg_chunks) > 1
|
||||
|
||||
# Concatenated arguments should form valid JSON
|
||||
full_args = "".join(arg_chunks)
|
||||
parsed_args = json.loads(full_args)
|
||||
assert parsed_args["city"] == "Dallas"
|
||||
assert parsed_args["state"] == "TX"
|
||||
|
||||
|
||||
def test_extract_tool_calls_complex_type_with_single_quote(
|
||||
qwen3_tool_parser_parametrized,
|
||||
):
|
||||
"""Test parameter type conversion based on tool schema"""
|
||||
tools = [
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "test_types",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_param": {"type": "integer"},
|
||||
"float_param": {"type": "float"},
|
||||
"bool_param": {"type": "boolean"},
|
||||
"str_param": {"type": "string"},
|
||||
"obj_param": {"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
model_output = """<tool_call>
|
||||
<function=test_types>
|
||||
<parameter=obj_param>
|
||||
{'key': 'value'}
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request
|
||||
)
|
||||
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
assert args["obj_param"] == {"key": "value"}
|
||||
|
||||
|
||||
def test_extract_tool_calls_streaming_missing_opening_tag(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
|
||||
):
|
||||
"""Test streaming with missing opening <tool_call> tag
|
||||
|
||||
This tests that the streaming parser correctly handles
|
||||
tool calls that start directly with <function=...>
|
||||
"""
|
||||
model_output = """I'll check the weather for you.
|
||||
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
other_content = ""
|
||||
tool_states = {}
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
|
||||
):
|
||||
if delta_message.content:
|
||||
other_content += delta_message.content
|
||||
|
||||
if delta_message.tool_calls:
|
||||
for tool_call in delta_message.tool_calls:
|
||||
idx = tool_call.index
|
||||
|
||||
if idx not in tool_states:
|
||||
tool_states[idx] = {
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
"type": None,
|
||||
}
|
||||
|
||||
if tool_call.id:
|
||||
tool_states[idx]["id"] = tool_call.id
|
||||
|
||||
if tool_call.type:
|
||||
assert tool_call.type == "function"
|
||||
tool_states[idx]["type"] = tool_call.type
|
||||
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
tool_states[idx]["name"] = tool_call.function.name
|
||||
|
||||
if tool_call.function.arguments is not None:
|
||||
tool_states[idx]["arguments"] += tool_call.function.arguments
|
||||
|
||||
# Verify content was streamed
|
||||
assert "I'll check the weather for you." in other_content
|
||||
|
||||
# Verify we got the tool call
|
||||
assert len(tool_states) == 1
|
||||
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1
|
||||
|
||||
state = tool_states[0]
|
||||
assert state["id"] is not None
|
||||
assert state["type"] == "function"
|
||||
assert state["name"] == "get_current_weather"
|
||||
|
||||
# Verify arguments were parsed correctly despite missing opening tag
|
||||
assert state["arguments"] is not None
|
||||
args = json.loads(state["arguments"])
|
||||
assert args["city"] == "Dallas"
|
||||
assert args["state"] == "TX"
|
||||
assert args["unit"] == "fahrenheit"
|
||||
495
tests/tool_parsers/test_seed_oss_tool_parser.py
Normal file
@@ -0,0 +1,495 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaMessage,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.tool_parsers.seed_oss_tool_parser import SeedOssToolParser
|
||||
|
||||
# Use a common model that is likely to be available
|
||||
MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def seed_oss_tokenizer():
|
||||
return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def seed_oss_tool_parser(seed_oss_tokenizer):
|
||||
return SeedOssToolParser(seed_oss_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools():
|
||||
return [
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for a given location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "City and country e.g. Bogotá, Colombia",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "this is the unit of temperature",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"returns": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"description": "temperature in celsius",
|
||||
}
|
||||
},
|
||||
"required": ["temperature"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
        # The Seed-OSS parser does not generate tool call IDs
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
assert actual_tool_call.function.name == expected_tool_call.function.name
|
||||
assert (
|
||||
actual_tool_call.function.arguments == expected_tool_call.function.arguments
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
model_output = "This is a test response without any tool calls"
|
||||
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"tool_call_0_thinking_budget",
|
||||
"tool_call_512_thinkg_budget",
|
||||
"tool_call_unlimited_thinking_budget",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""<seed:tool_call>\n<function=get_weather>\n"""
|
||||
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
|
||||
"""\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
|
||||
),
|
||||
(
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
|
||||
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
"unit": "celsius",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think>""",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(
|
||||
seed_oss_tool_parser,
|
||||
sample_tools,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
|
||||
model_output, request=request
|
||||
) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
model_output = "This is a test response without any tool calls"
|
||||
|
||||
result = seed_oss_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="his is a test response",
|
||||
current_text=model_output,
|
||||
delta_text=" without any tool calls.",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Should return the delta text as content
|
||||
assert result is not None
|
||||
assert hasattr(result, "content")
|
||||
assert result.content == " without any tool calls."
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
seed_oss_tool_parser: SeedOssToolParser,
|
||||
seed_oss_tokenizer: TokenizerLike,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest | None = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = seed_oss_tokenizer.encode(model_output, add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[: i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
|
||||
detokenize_incrementally(
|
||||
tokenizer=seed_oss_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message = seed_oss_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
request=request,
|
||||
)
|
||||
if delta_message:
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (
|
||||
previous_tokens + new_tokens if previous_tokens else new_tokens
|
||||
)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"tool_call_0_thinking_budget",
|
||||
"tool_call_512_thinkg_budget",
|
||||
"tool_call_unlimited_thinking_budget",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n"""
|
||||
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""",
|
||||
),
|
||||
(
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
|
||||
"""\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
|
||||
),
|
||||
(
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
|
||||
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
"unit": "celsius",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think>""",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_streaming_tool_calls(
|
||||
seed_oss_tool_parser,
|
||||
seed_oss_tokenizer,
|
||||
sample_tools,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
"""Test incremental streaming behavior"""
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
other_content = ""
|
||||
tool_states = {} # Track state per tool index
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
seed_oss_tool_parser, seed_oss_tokenizer, model_output, request
|
||||
):
|
||||
# role should never be streamed from tool parser
|
||||
assert not delta_message.role
|
||||
|
||||
if delta_message.content:
|
||||
other_content += delta_message.content
|
||||
|
||||
if delta_message.tool_calls:
|
||||
for tool_call in delta_message.tool_calls:
|
||||
idx = tool_call.index
|
||||
|
||||
# Initialize state for new tool
|
||||
if idx not in tool_states:
|
||||
tool_states[idx] = {
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
"type": None,
|
||||
}
|
||||
|
||||
# First chunk should have id, name, and type
|
||||
if tool_call.id:
|
||||
tool_states[idx]["id"] = tool_call.id
|
||||
|
||||
if tool_call.type:
|
||||
assert tool_call.type == "function"
|
||||
tool_states[idx]["type"] = tool_call.type
|
||||
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
# Should only be set once
|
||||
assert tool_states[idx]["name"] is None
|
||||
tool_states[idx]["name"] = tool_call.function.name
|
||||
|
||||
if tool_call.function.arguments is not None:
|
||||
# Accumulate arguments incrementally
|
||||
tool_states[idx]["arguments"] += tool_call.function.arguments
|
||||
|
||||
# Verify final content
|
||||
assert other_content == expected_content
|
||||
|
||||
# Verify we got all expected tool calls
|
||||
assert len(tool_states) == len(expected_tool_calls)
|
||||
|
||||
# Verify each tool call
|
||||
for idx, expected_tool in enumerate(expected_tool_calls):
|
||||
state = tool_states[idx]
|
||||
assert state["id"] is not None
|
||||
assert state["type"] == "function"
|
||||
assert state["name"] == expected_tool.function.name
|
||||
|
||||
# Parse accumulated arguments
|
||||
arguments_str = state["arguments"]
|
||||
assert arguments_str is not None
|
||||
actual_args = json.loads(arguments_str)
|
||||
expected_args = json.loads(expected_tool.function.arguments)
|
||||
assert actual_args == expected_args
|
||||
534
tests/tool_parsers/test_xlam_tool_parser.py
Normal file
@@ -0,0 +1,534 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.tool_parsers.xlam_tool_parser import xLAMToolParser
|
||||
|
||||
# Use a common model that is likely to be available
|
||||
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def xlam_tokenizer():
|
||||
return get_tokenizer(tokenizer_name=MODEL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def xlam_tool_parser(xlam_tokenizer):
|
||||
return xLAMToolParser(xlam_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
|
||||
):
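    """Check parsed tool calls: each must carry a generated string ID longer than
    16 characters, type "function", and a function equal to the expected one."""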
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) > 16
|
||||
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
xlam_tool_parser: xLAMToolParser,
|
||||
xlam_tokenizer: TokenizerLike,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest | None = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = xlam_tokenizer.encode(model_output, add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[: i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
|
||||
detokenize_incrementally(
|
||||
tokenizer=xlam_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message = xlam_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
request=request,
|
||||
)
|
||||
if delta_message:
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (
|
||||
previous_tokens + new_tokens if previous_tokens else new_tokens
|
||||
)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"parallel_tool_calls",
|
||||
"single_tool_with_think_tag",
|
||||
"single_tool_with_json_code_block",
|
||||
"single_tool_with_tool_calls_tag",
|
||||
"single_tool_with_tool_call_xml_tags",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"<think>I'll help you with that.</think>",
|
||||
),
|
||||
(
|
||||
"""I'll help you with that.\n```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"I'll help you with that.",
|
||||
),
|
||||
(
|
||||
"""I'll check the weather for you.[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"I'll check the weather for you.",
|
||||
),
|
||||
(
|
||||
"""I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"I'll help you check the weather.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(
|
||||
xlam_tool_parser, model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=["list_structured_tool_call"],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""[{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Seattle",
|
||||
"state": "WA",
|
||||
"unit": "celsius",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_list_structure(
    xlam_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Test extraction of tool calls when the model outputs a list-structured tool call."""  # noqa: E501
    extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


# Test for preprocess_model_output method
def test_preprocess_model_output(xlam_tool_parser):
    # Test with list structure
    model_output = (
        """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501
    )
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content is None
    assert potential_tool_calls == model_output

    # Test with thinking tag
    model_output = """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content == "<think>I'll help you with that.</think>"
    assert (
        potential_tool_calls
        == '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]'
    )

    # Test with JSON code block
    model_output = """I'll help you with that.
```json
[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]
```"""
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content == "I'll help you with that."
    assert "get_current_weather" in potential_tool_calls

    # Test with no tool calls
    model_output = """I'll help you with that."""
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content == model_output
    assert potential_tool_calls is None


# Simulate streaming to test extract_tool_calls_streaming
def test_streaming_with_list_structure(xlam_tool_parser):
    # Reset streaming state
    xlam_tool_parser.prev_tool_calls = []
    xlam_tool_parser.current_tools_sent = []
    xlam_tool_parser.streamed_args = []
    xlam_tool_parser.current_tool_id = -1

    # Simulate receiving a message with list structure
    current_text = (
        """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501
    )

    # First call to set up the tool
    xlam_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=current_text,
        delta_text="]",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Make sure the tool is set up correctly
    assert xlam_tool_parser.current_tool_id >= 0, "Tool index should be initialized"

    # Manually set up the state for sending the tool name
    xlam_tool_parser.current_tools_sent = [False]

    # Call to send the function name
    result = xlam_tool_parser.extract_tool_calls_streaming(
        previous_text=current_text,
        current_text=current_text,
        delta_text="",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Check that we get a result with the proper tool call
    if result is not None:
        assert hasattr(result, "tool_calls")
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "get_current_weather"


@pytest.mark.parametrize(
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
        "single_tool_with_tool_call_xml_tags",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Orlando",
                                "state": "FL",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
            ],
            "",
        ),
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "<think>I'll help you with that.</think>",
        ),
        (
            """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "I can help with that.",
        ),
    ],
)
def test_extract_tool_calls_streaming_incremental(
    xlam_tool_parser,
    xlam_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Verify the XLAM Parser streaming behavior by verifying each chunk is as expected."""  # noqa: E501
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])

    chunks = []
    for delta_message in stream_delta_message_generator(
        xlam_tool_parser, xlam_tokenizer, model_output, request
    ):
        chunks.append(delta_message)

    # Should have multiple chunks
    assert len(chunks) >= 3

    # Should have a chunk with tool header (id, name, type) for the first tool call  # noqa: E501
    header_found = False
    expected_first_tool = expected_tool_calls[0]
    for chunk in chunks:
        if chunk.tool_calls and chunk.tool_calls[0].id:
            header_found = True
            assert (
                chunk.tool_calls[0].function.name == expected_first_tool.function.name
            )
            assert chunk.tool_calls[0].type == "function"
            # Arguments may be empty initially or None
            if chunk.tool_calls[0].function.arguments is not None:
                # If present, should be empty string initially
                assert chunk.tool_calls[0].function.arguments == ""
            break
    assert header_found

    # Should have chunks with incremental arguments
    arg_chunks = []
    for chunk in chunks:
        if (
            chunk.tool_calls
            and chunk.tool_calls[0].function.arguments
            and chunk.tool_calls[0].function.arguments != ""
            and chunk.tool_calls[0].index
            == 0  # Only collect arguments from the first tool call
        ):
            arg_chunks.append(chunk.tool_calls[0].function.arguments)

    # Arguments should be streamed incrementally
    assert len(arg_chunks) > 1

    # Concatenated arguments should form valid JSON for the first tool call
    full_args = "".join(arg_chunks)
    parsed_args = json.loads(full_args)
    expected_args = json.loads(expected_first_tool.function.arguments)
    assert parsed_args == expected_args
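

# The streaming test above relies on a stream_delta_message_generator helper
# defined elsewhere in this module. The sketch below is only a hypothetical
# illustration of how such a generator could be written: it re-tokenizes the
# full model output and replays it one token at a time, yielding whatever
# delta message the parser produces at each step. The name and the internals
# are assumptions for illustration, not the helper the test actually uses.
def _example_stream_delta_message_generator(
    tool_parser, tokenizer, model_output, request
):
    # Encode once, then feed the parser a growing prefix of the output.
    token_ids = tokenizer.encode(model_output, add_special_tokens=False)
    previous_text = ""
    previous_token_ids: list[int] = []
    for i, token_id in enumerate(token_ids):
        current_token_ids = token_ids[: i + 1]
        current_text = tokenizer.decode(current_token_ids)
        delta_text = current_text[len(previous_text):]
        delta_message = tool_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=previous_token_ids,
            current_token_ids=current_token_ids,
            delta_token_ids=[token_id],
            request=request,
        )
        if delta_message is not None:
            yield delta_message
        previous_text = current_text
        previous_token_ids = current_token_ids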