Sync from v0.13
This commit is contained in:
0
tests/entrypoints/openai/tool_parsers/__init__.py
Normal file
0
tests/entrypoints/openai/tool_parsers/__init__.py
Normal file
12
tests/entrypoints/openai/tool_parsers/conftest.py
Normal file
12
tests/entrypoints/openai/tool_parsers/conftest.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def default_tokenizer() -> TokenizerLike:
    """Provide the ``gpt2`` tokenizer, loaded anew for every test function."""
    gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return gpt2_tokenizer
|
||||
@@ -0,0 +1,176 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
# --- "simple" case: a flat arguments object --------------------------------
SIMPLE_ARGS_DICT = dict(action="create", id="preferences")
SIMPLE_FUNCTION_JSON = json.dumps(
    {"name": "manage_user_memory", "arguments": SIMPLE_ARGS_DICT},
    ensure_ascii=False,
)
# Model output is the literal "function call" marker followed by the JSON.
SIMPLE_FUNCTION_OUTPUT = f"function call{SIMPLE_FUNCTION_JSON}"
SIMPLE_FUNCTION_CALL = FunctionCall(
    arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False),
    name="manage_user_memory",
)


# --- "parameterless" case: an empty arguments object -----------------------
PARAMETERLESS_FUNCTION_JSON = json.dumps(
    {"name": "manage_user_memory", "arguments": {}},
    ensure_ascii=False,
)
PARAMETERLESS_FUNCTION_OUTPUT = f"function call{PARAMETERLESS_FUNCTION_JSON}"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
    arguments=json.dumps({}, ensure_ascii=False),
    name="manage_user_memory",
)


# --- "complex" case: arguments containing a nested object ------------------
COMPLEX_ARGS_DICT = {
    "action": "create",
    "id": "preferences",
    "content": dict(
        short_answers=True,
        hate_emojis=True,
        english_ui=False,
        russian_math_explanations=True,
    ),
}
COMPLEX_FUNCTION_JSON = json.dumps(
    {"name": "manage_user_memory", "arguments": COMPLEX_ARGS_DICT},
    ensure_ascii=False,
)
COMPLEX_FUNCTION_OUTPUT = f"function call{COMPLEX_FUNCTION_JSON}"
COMPLEX_FUNCTION_CALL = FunctionCall(
    arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False),
    name="manage_user_memory",
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text must come back unchanged as content, with no tool calls.

    Runs once through the streaming path and once through the non-streaming
    path of the ``gigachat3`` parser.
    """
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    tool_parser: ToolParser = parser_cls(default_tokenizer)
    plain_reply = "How can I help you today?"

    content, tool_calls = run_tool_extraction(
        tool_parser,
        plain_reply,
        streaming=streaming,
    )

    assert content == plain_reply
    assert len(tool_calls) == 0
|
||||
|
||||
|
||||
# One streaming and one non-streaming case per fixture, in the order
# simple, parameterless, complex. Every case expects exactly one tool call
# and no leftover content (``None``).
TEST_CASES = [
    pytest.param(
        streaming,
        model_output,
        [expected_call],
        None,
        id=f"{label}_{mode}",
    )
    for label, model_output, expected_call in (
        ("simple", SIMPLE_FUNCTION_OUTPUT, SIMPLE_FUNCTION_CALL),
        (
            "parameterless",
            PARAMETERLESS_FUNCTION_OUTPUT,
            PARAMETERLESS_FUNCTION_CALL,
        ),
        ("complex", COMPLEX_FUNCTION_OUTPUT, COMPLEX_FUNCTION_CALL),
    )
    for streaming, mode in ((True, "streaming"), (False, "nonstreaming"))
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "streaming, model_output, expected_tool_calls, expected_content", TEST_CASES
)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    expected_content: str | None,
    default_tokenizer: TokenizerLike,
):
    """Check the calls and content the ``gigachat3`` parser extracts.

    Each extracted call must be of type ``"function"``, carry the expected
    name, and have arguments that decode to the same JSON value as the
    expected call's arguments.
    """
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    content, tool_calls = run_tool_extraction(
        tool_parser,
        model_output,
        streaming=streaming,
    )

    assert content == expected_content
    assert len(tool_calls) == len(expected_tool_calls)
    for produced, wanted in zip(tool_calls, expected_tool_calls):
        assert produced.type == "function"
        assert produced.function.name == wanted.name
        # Compare decoded values rather than raw strings so that JSON
        # formatting differences do not matter.
        produced_args = json.loads(produced.function.arguments)
        wanted_args = json.loads(wanted.arguments)
        assert produced_args == wanted_args
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """Stream the complex call in three coarse deltas and check the result.

    The deltas are the ``"function call"`` marker, the first 40 characters of
    the JSON payload, and the remainder of the payload.
    """
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    split_at = 40
    model_output_deltas = [
        "function call",
        COMPLEX_FUNCTION_JSON[:split_at],
        COMPLEX_FUNCTION_JSON[split_at:],
    ]

    reconstructor = run_tool_extraction_streaming(
        tool_parser,
        model_output_deltas,
        assert_one_tool_per_delta=False,
    )

    assert len(reconstructor.tool_calls) == 1
    only_call = reconstructor.tool_calls[0]
    assert only_call.type == "function"
    assert only_call.function.name == "manage_user_memory"
    assert json.loads(only_call.function.arguments) == COMPLEX_ARGS_DICT
|
||||
460
tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
Normal file
460
tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
Normal file
@@ -0,0 +1,460 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
# Base model passed to RemoteOpenAIServer, and the LoRA adapter that the
# requests below address via ``model=LORA_MODEL``.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
LORA_MODEL = "minpeter/LoRA-Llama-3.2-1B-tool-vllm-ci"

# Command-line flags for the server under test: the hermes tool-call parser
# with automatic tool choice, with the LoRA adapter registered under its own
# name and also used as the tokenizer.
SERVER_ARGS = [
    "--enforce-eager",
    "--enable-auto-tool-choice",
    "--tool-call-parser",
    "hermes",
    "--enable-lora",
    "--lora-modules",
    f"{LORA_MODEL}={LORA_MODEL}",
    "--tokenizer",
    LORA_MODEL,
]

# Tool schema with one required string parameter and one optional enum.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["location"],
            },
        },
    }
]

# Tool schema with a boolean and an integer parameter, both required; the
# product tests use it to check that argument types survive parsing.
PRODUCT_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_product_info",
            "description": "Get detailed information of a product based on its "
            "product ID.",
            "parameters": {
                "type": "object",
                "properties": {
                    "inserted": {
                        "type": "boolean",
                        "description": "inserted.",
                    },
                    "product_id": {
                        "type": "integer",
                        "description": "The product ID of the product.",
                    },
                },
                "required": ["product_id", "inserted"],
            },
        },
    }
]

MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]

PRODUCT_MESSAGES = [
    {
        "role": "user",
        "content": "Hi! Do you have any detailed information about the product id "
        "7355608 and inserted true?",
    }
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_streaming_tool_call():
    """Request a weather tool call from a live server without streaming.

    The response must finish with reason ``"tool_calls"`` and its first call
    must target ``get_current_weather`` with a location mentioning Boston.
    """
    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
        client = server.get_async_client()

        response = await client.chat.completions.create(
            model=LORA_MODEL,
            messages=MESSAGES,
            tools=TOOLS,
            tool_choice="auto",
            temperature=0.0,
        )

        assert response.choices
        first_choice = response.choices[0]

        assert first_choice.finish_reason == "tool_calls"
        assert first_choice.message.tool_calls is not None

        tool_call = first_choice.message.tool_calls[0]
        assert tool_call.type == "function"
        assert tool_call.function.name == "get_current_weather"

        arguments = json.loads(tool_call.function.arguments)
        assert "location" in arguments
        assert "Boston" in arguments["location"]

        print("\n[Non-Streaming Test Passed]")
        print(f"Tool Call: {tool_call.function.name}")
        print(f"Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_tool_call():
    """Request a weather tool call from a live server with streaming on.

    Name and argument fragments are concatenated per tool-call index; exactly
    one call must result, targeting ``get_current_weather`` with a location
    mentioning Boston.
    """
    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
        client = server.get_async_client()

        stream = await client.chat.completions.create(
            model=LORA_MODEL,
            messages=MESSAGES,
            tools=TOOLS,
            tool_choice="auto",
            temperature=0.0,
            stream=True,
        )

        # Maps tool-call index -> accumulated name / arguments text.
        tool_call_chunks = {}
        async for chunk in stream:
            if not chunk.choices:
                continue

            delta = chunk.choices[0].delta
            if not delta or not delta.tool_calls:
                continue

            for piece in delta.tool_calls:
                slot = tool_call_chunks.setdefault(
                    piece.index, {"name": "", "arguments": ""}
                )
                if piece.function.name:
                    slot["name"] += piece.function.name
                if piece.function.arguments:
                    slot["arguments"] += piece.function.arguments

        assert len(tool_call_chunks) == 1
        reconstructed_tool_call = tool_call_chunks[0]

        assert reconstructed_tool_call["name"] == "get_current_weather"

        arguments = json.loads(reconstructed_tool_call["arguments"])
        assert "location" in arguments
        assert "Boston" in arguments["location"]

        print("\n[Streaming Test Passed]")
        print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
        print(f"Reconstructed Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_streaming_product_tool_call():
    """Check integer and boolean tool arguments without streaming.

    The first call must target ``get_product_info`` with ``product_id`` parsed
    as the int ``7355608`` and ``inserted`` parsed as the bool ``True``.
    """
    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
        client = server.get_async_client()

        response = await client.chat.completions.create(
            model=LORA_MODEL,
            messages=PRODUCT_MESSAGES,
            tools=PRODUCT_TOOLS,
            tool_choice="auto",
            temperature=0.66,
        )

        assert response.choices
        first_choice = response.choices[0]

        assert first_choice.finish_reason == "tool_calls"
        assert first_choice.message.tool_calls is not None

        tool_call = first_choice.message.tool_calls[0]
        assert tool_call.type == "function"
        assert tool_call.function.name == "get_product_info"

        arguments = json.loads(tool_call.function.arguments)
        assert "product_id" in arguments
        assert "inserted" in arguments

        product_id, inserted = (
            arguments.get("product_id"),
            arguments.get("inserted"),
        )

        assert isinstance(product_id, int)
        assert product_id == 7355608
        assert isinstance(inserted, bool)
        assert inserted is True

        print("\n[Non-Streaming Product Test Passed]")
        print(f"Tool Call: {tool_call.function.name}")
        print(f"Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_product_tool_call():
    """Check integer and boolean tool arguments with streaming on.

    Fragments are concatenated per tool-call index; the single resulting call
    must target ``get_product_info`` with ``product_id`` decoded as the int
    ``7355608`` and ``inserted`` decoded as the bool ``True``.
    """
    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
        client = server.get_async_client()

        stream = await client.chat.completions.create(
            model=LORA_MODEL,
            messages=PRODUCT_MESSAGES,
            tools=PRODUCT_TOOLS,
            tool_choice="auto",
            temperature=0.66,
            stream=True,
        )

        # Maps tool-call index -> accumulated name / arguments text.
        tool_call_chunks = {}
        async for chunk in stream:
            if not chunk.choices:
                continue

            delta = chunk.choices[0].delta
            if not delta or not delta.tool_calls:
                continue

            for piece in delta.tool_calls:
                slot = tool_call_chunks.setdefault(
                    piece.index, {"name": "", "arguments": ""}
                )
                if piece.function.name:
                    slot["name"] += piece.function.name
                if piece.function.arguments:
                    slot["arguments"] += piece.function.arguments

        assert len(tool_call_chunks) == 1
        reconstructed_tool_call = tool_call_chunks[0]

        assert reconstructed_tool_call["name"] == "get_product_info"

        arguments = json.loads(reconstructed_tool_call["arguments"])
        assert "product_id" in arguments
        assert "inserted" in arguments

        # Same type checks as the non-streaming product test.
        product_id, inserted = (
            arguments.get("product_id"),
            arguments.get("inserted"),
        )

        assert isinstance(product_id, int)
        assert product_id == 7355608
        assert isinstance(inserted, bool)
        assert inserted is True

        print("\n[Streaming Product Test Passed]")
        print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
        print(f"Reconstructed Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.fixture
def qwen_tokenizer() -> TokenizerLike:
    """Provide the tokenizer for ``Qwen/Qwen3-32B``."""
    # Imported inside the fixture, as in the original.
    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer("Qwen/Qwen3-32B")
    return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
    """Provide a Hermes tool parser built on the Qwen tokenizer fixture."""
    parser_under_test = Hermes2ProToolParser(qwen_tokenizer)
    return parser_under_test
|
||||
|
||||
|
||||
@pytest.fixture
def any_chat_request() -> ChatCompletionRequest:
    """Provide a minimal chat request: fixed seed, Qwen model, no messages."""
    request_fields = {
        "seed": 42,
        "model": "Qwen/Qwen3-32B",
        "messages": [],
    }
    return ChatCompletionRequest(**request_fields)
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_just_forward_text(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Text without tool-call markup must be forwarded verbatim when streamed.

    The text is fed one decoded token at a time. Every delta must be non-None
    and free of tool calls, and the delta contents joined together must equal
    the input text.
    """
    text = """This is some prior text that has nothing to do with tool calling."""

    delta_messages = []
    seen_so_far = ""
    for token_id in qwen_tokenizer.encode(text):
        piece = qwen_tokenizer.decode([token_id])
        grown = seen_so_far + piece
        delta_messages.append(
            hermes_parser.extract_tool_calls_streaming(
                previous_text=seen_so_far,
                current_text=grown,
                delta_text=piece,
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[],
                request=any_chat_request,
            )
        )
        seen_so_far = grown

    for message in delta_messages:
        assert message is not None
        assert not message.tool_calls

    print(delta_messages)
    assert "".join(message.content for message in delta_messages) == text
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_failure_case_bug_19056(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Stream a newline-delimited ``<tool_call>`` block one token at a time.

    The first emitted delta must carry the function name ``final_answer``,
    and the argument fragments of all emitted deltas must join to
    ``{"trigger": true}``.
    """
    full_text = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}}\n'
        "</tool_call>"
    )

    delta_messages = []
    seen_so_far = ""
    for token_id in qwen_tokenizer.encode(full_text):
        piece = qwen_tokenizer.decode([token_id])
        grown = seen_so_far + piece
        delta = hermes_parser.extract_tool_calls_streaming(
            previous_text=seen_so_far,
            current_text=grown,
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        seen_so_far = grown
        # Steps for which the parser returns None are skipped.
        if delta is not None:
            delta_messages.append(delta)

    assert delta_messages[0].tool_calls[0].function.name == "final_answer"
    argument_fragments = [
        message.tool_calls[0].function.arguments or ""
        for message in delta_messages
    ]
    tool_call_args = "".join(argument_fragments)
    assert tool_call_args == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_streaming(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Stream a single-line ``<tool_call>`` block one token at a time.

    The first emitted delta must carry the function name
    ``get_current_temperature``, and the joined argument fragments must
    reproduce the arguments object exactly as written in the input.
    """
    # Adjacent literals concatenate with no separators between the pieces.
    full_text = (
        "<tool_call>"
        '{"name": "get_current_temperature",'
        '"arguments": {"location":'
        '"San Francisco, California, United States", "unit": "celsius"}}'
        "</tool_call>"
    )

    delta_messages = []
    seen_so_far = ""
    for token_id in qwen_tokenizer.encode(full_text):
        piece = qwen_tokenizer.decode([token_id])
        grown = seen_so_far + piece
        delta = hermes_parser.extract_tool_calls_streaming(
            previous_text=seen_so_far,
            current_text=grown,
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        seen_so_far = grown
        # Steps for which the parser returns None are skipped.
        if delta is not None:
            delta_messages.append(delta)

    print(delta_messages)
    first_call = delta_messages[0].tool_calls[0]
    assert first_call.function.name == "get_current_temperature"
    argument_fragments = [
        message.tool_calls[0].function.arguments or ""
        for message in delta_messages
    ]
    tool_call_args = "".join(argument_fragments)
    assert tool_call_args == (
        '{"location":"San Francisco, California, United States", "unit": "celsius"}'
    )
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_no_tool_call(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Plain text must yield a result object that reports no tools called."""
    extraction = hermes_parser.extract_tool_calls(
        model_output="This is not a tool call.",
        request=any_chat_request,
    )

    assert extraction is not None
    assert not extraction.tools_called
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_between_tags(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A call wrapped in opening and closing ``<tool_call>`` tags is extracted."""
    wrapped_call = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}}\n'
        "</tool_call>"
    )
    extraction = hermes_parser.extract_tool_calls(
        model_output=wrapped_call,
        request=any_chat_request,
    )

    assert extraction is not None
    assert extraction.tools_called
    first_function = extraction.tool_calls[0].function
    assert first_function.name == "final_answer"
    assert first_function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_until_eos(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A call with no closing ``</tool_call>`` tag is still extracted."""
    unterminated_call = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}}'
    )
    extraction = hermes_parser.extract_tool_calls(
        model_output=unterminated_call,
        request=any_chat_request,
    )

    assert extraction is not None
    assert extraction.tools_called
    first_function = extraction.tool_calls[0].function
    assert first_function.name == "final_answer"
    assert first_function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_invalid_json(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Malformed JSON inside ``<tool_call>`` must report no tools called."""
    # The outer object's closing brace is missing, so the payload is invalid.
    broken_call = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}'
    )
    extraction = hermes_parser.extract_tool_calls(
        model_output=broken_call,
        request=any_chat_request,
    )

    assert extraction is not None
    assert not extraction.tools_called
|
||||
@@ -0,0 +1,179 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
|
||||
def make_tool_call(name, arguments):
    """Build a ``"function"``-type ToolCall with JSON-encoded *arguments*."""
    function = FunctionCall(name=name, arguments=json.dumps(arguments))
    return ToolCall(type="function", function=function)
|
||||
|
||||
|
||||
# TODO: add reason prefix and suffix.
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_output,expected_tool_calls,expected_content",
    [
        # No tool call
        ("How can I help you today?", [], "How can I help you today?"),
        # Single tool call, no content
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]</tool_calls>',  # noqa: E501
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                )
            ],
            None,
        ),
        # Multiple tool calls
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]</tool_calls>',  # noqa: E501
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                ),
                make_tool_call(
                    "register_user",
                    {
                        "name": "John Doe",
                        "age": 37,
                        "address": {"city": "San Francisco", "state": "CA"},
                        "role": None,
                        "passed_test": True,
                        "aliases": ["John", "Johnny"],
                    },
                ),
            ],
            None,
        ),
        # Content before tool call
        (
            'I will call the tool now. <tool_calls>[{"name": "get_weather", "arguments": {"city": "Boston"}}]</tool_calls>',  # noqa: E501
            [make_tool_call("get_weather", {"city": "Boston"})],
            "I will call the tool now. ",
        ),
        # Content after tool call (should be stripped)
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "Seattle"}}]</tool_calls>\nThank you!',  # noqa: E501
            [make_tool_call("get_weather", {"city": "Seattle"})],
            None,
        ),
        # Deeply nested arguments
        (
            '<tool_calls>[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]</tool_calls>',
            [
                make_tool_call(
                    "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
                )
            ],
            None,
        ),
    ],
)
def test_hunyuan_a13b_tool_parser_extract(
    model_output, expected_tool_calls, expected_content
):
    """Check non-streaming extraction by the ``hunyuan_a13b`` parser.

    Args:
        model_output: Raw text handed to the parser.
        expected_tool_calls: ToolCall objects the parser should produce.
        expected_content: Content expected alongside the calls, or ``None``.
    """
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
        mock_tokenizer
    )
    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=False
    )

    # Align the random ids so the equality check below ignores them. zip()
    # stops at the shorter list, so a wrong number of calls is reported by
    # the assertion below rather than as an IndexError raised here.
    for produced, wanted in zip(tool_calls, expected_tool_calls):
        produced.id = wanted.id
    assert tool_calls == expected_tool_calls
    assert content == expected_content
|
||||
|
||||
|
||||
# Streaming test: simulate incremental output
|
||||
@pytest.mark.parametrize(
    "model_deltas,expected_tool_calls",
    [
        (
            [
                '<tool_calls>[{"name": "get_weather", ',
                '"arguments": {"city": "San Francisco", ',
                '"metric": "celsius"}}]',
                "</tool_calls>",
            ],
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                )
            ],
        ),
        (
            [
                '<tool_calls>[{"name":',
                ' "get_weather",',
                ' "arguments":',
                ' {"city": "Boston"}',
                "}]",
                "</tool_calls>",
            ],
            [make_tool_call("get_weather", {"city": "Boston"})],
        ),
        (
            [
                "",
                '<tool_calls>[{"name":',
                ' "get_weather",',
                ' "arguments":',
                ' {"city": "Boston"}',
                "}]",
                "</tool_calls>",
                "\n</answer>",
            ],
            [make_tool_call("get_weather", {"city": "Boston"})],
        ),
        pytest.param(
            [
                '<tool_calls>[{"name": "complex_tool",',
                ' "arguments": ',
                ' {"level1": {"level2": ',
                '{"level3": {"value": 123}}}}}',
                "]</tool_calls>",
            ],
            [
                make_tool_call(
                    "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
                )
            ],
            marks=pytest.mark.xfail(
                reason="stream parsing not support nested json yet."
            ),
        ),
    ],
)
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
    """Check streaming extraction by the ``hunyuan_a13b`` parser.

    Args:
        model_deltas: Text fragments fed to the parser in order.
        expected_tool_calls: ToolCall objects the reconstructed stream should
            contain.
    """
    mock_tokenizer = MagicMock()

    tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
        mock_tokenizer
    )
    reconstructor = run_tool_extraction_streaming(
        tool_parser, model_deltas, assert_one_tool_per_delta=False
    )

    # Align the random ids so the equality check below ignores them. zip()
    # stops at the shorter list, so a wrong number of calls is reported by
    # the assertion below rather than as an IndexError raised here.
    for produced, wanted in zip(reconstructor.tool_calls, expected_tool_calls):
        produced.id = wanted.id

    assert reconstructor.tool_calls == expected_tool_calls
|
||||
@@ -0,0 +1,262 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.llama_tool_parser import Llama3JsonToolParser
|
||||
|
||||
|
||||
@pytest.fixture
def parser(default_tokenizer: TokenizerLike):
    """Provide a Llama3JsonToolParser built on the default tokenizer fixture."""
    llama_parser = Llama3JsonToolParser(default_tokenizer)
    return llama_parser
|
||||
|
||||
|
||||
def test_extract_tool_calls_simple(parser):
    """A single JSON call embedded in prose yields one call and no content."""
    model_output = (
        "Here is the result: "
        '{"name": "getOpenIncidentsTool", "parameters": {}}'
        " Would you like to know more?"
    )
    extracted = parser.extract_tool_calls(model_output, None)

    assert isinstance(extracted, ExtractedToolCallInformation)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1
    only_call = extracted.tool_calls[0]
    assert only_call.type == "function"
    assert only_call.function.name == "getOpenIncidentsTool"
    assert only_call.function.arguments == "{}"
    assert extracted.content is None
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_arguments(parser):
    """A call whose ``parameters`` object is non-empty keeps its arguments."""
    model_output = (
        '{"name": "searchTool", '
        '"parameters": {"query": "test query", "limit": 10}}'
    )
    extracted = parser.extract_tool_calls(model_output, None)

    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1
    function = extracted.tool_calls[0].function
    assert function.name == "searchTool"
    assert '"query": "test query"' in function.arguments
    assert '"limit": 10' in function.arguments
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_json(parser):
    """Text with no JSON object is returned unchanged as content."""
    plain_text = "This is just some text without any tool calls"
    extracted = parser.extract_tool_calls(plain_text, None)

    assert extracted.tools_called is False
    assert len(extracted.tool_calls) == 0
    assert extracted.content == plain_text
|
||||
|
||||
|
||||
def test_extract_tool_calls_invalid_json(parser):
    """Malformed JSON yields no tool calls; the raw text becomes the content."""
    malformed = '{"name": "invalidTool", "parameters": {invalid json}'
    extracted = parser.extract_tool_calls(malformed, None)

    assert extracted.tools_called is False
    assert len(extracted.tool_calls) == 0
    assert extracted.content == malformed
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_arguments_key(parser):
    """The ``"arguments"`` key is accepted in place of ``"parameters"``."""
    extracted = parser.extract_tool_calls(
        '{"name": "searchTool", "arguments": {"query": "test"}}', None
    )

    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1
    function = extracted.tool_calls[0].function
    assert function.name == "searchTool"
    assert '"query": "test"' in function.arguments
|
||||
|
||||
|
||||
def test_extract_tool_calls_multiple_json(parser):
    """Three semicolon-separated JSON calls are extracted in order."""
    model_output = "".join(
        [
            '{"name": "searchTool", "parameters": {"query": "test1"}}; ',
            '{"name": "getOpenIncidentsTool", "parameters": {}}; ',
            '{"name": "searchTool", "parameters": {"query": "test2"}}',
        ]
    )
    extracted = parser.extract_tool_calls(model_output, None)

    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 3

    # First call: a search carrying the first query.
    first_function = extracted.tool_calls[0].function
    assert first_function.name == "searchTool"
    assert '"query": "test1"' in first_function.arguments

    # Second call: no parameters.
    second_function = extracted.tool_calls[1].function
    assert second_function.name == "getOpenIncidentsTool"
    assert second_function.arguments == "{}"

    # Third call: a search carrying the second query.
    third_function = extracted.tool_calls[2].function
    assert third_function.name == "searchTool"
    assert '"query": "test2"' in third_function.arguments
|
||||
|
||||
|
||||
def test_extract_tool_calls_multiple_json_with_whitespace(parser):
    """Spaces on both sides of the semicolon separators are tolerated."""
    model_output = "".join(
        [
            '{"name": "searchTool", "parameters": {"query": "test1"}} ; ',
            '{"name": "getOpenIncidentsTool", "parameters": {}} ; ',
            '{"name": "searchTool", "parameters": {"query": "test2"}}',
        ]
    )
    extracted = parser.extract_tool_calls(model_output, None)

    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 3
    first_call, second_call, third_call = extracted.tool_calls
    assert first_call.function.name == "searchTool"
    assert second_call.function.name == "getOpenIncidentsTool"
    assert third_call.function.name == "searchTool"
|
||||
|
||||
|
||||
def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
    """Tool calls are still found when prose precedes and follows the JSON."""
    # Test with multiple JSONs and surrounding text
    pieces = (
        "Here are the results: ",
        '{"name": "searchTool", "parameters": {"query": "test1"}}; ',
        '{"name": "getOpenIncidentsTool", "parameters": {}}; ',
        '{"name": "searchTool", "parameters": {"query": "test2"}} ',
        "Would you like to know more?",
    )
    extraction = parser.extract_tool_calls("".join(pieces), None)

    assert extraction.tools_called is True
    assert len(extraction.tool_calls) == 3
    expected_names = ("searchTool", "getOpenIncidentsTool", "searchTool")
    for position, expected_name in enumerate(expected_names):
        assert extraction.tool_calls[position].function.name == expected_name
|
||||
|
||||
|
||||
def test_extract_tool_calls_deeply_nested_json(parser):
    """Parameters nested five objects deep survive extraction intact."""
    import json

    # Test with deeply nested JSON parameters (5 levels)
    pieces = (
        '{"name": "complexTool", ',
        '"parameters": {',
        '"level1": {',
        '"level2": {',
        '"level3": {',
        '"level4": {',
        '"value": "deep"',
        "}}}}}}",
    )
    extraction = parser.extract_tool_calls("".join(pieces), None)

    assert extraction.tools_called is True
    assert len(extraction.tool_calls) == 1
    only_call = extraction.tool_calls[0].function
    assert only_call.name == "complexTool"

    # Decode the arguments and walk down to the innermost value.
    decoded = json.loads(only_call.arguments)
    assert decoded["level1"]["level2"]["level3"]["level4"]["value"] == "deep"
|
||||
|
||||
|
||||
def test_extract_tool_calls_multiple_with_deep_nesting(parser):
    """A flat call followed by a deeply nested call are both extracted."""
    import json

    # Test with multiple tool calls where some have deeply nested parameters
    pieces = (
        '{"name": "simpleTool", "parameters": {"value": "test"}}; ',
        '{"name": "complexTool", "parameters": ',
        '{"config": {"database": {"connection": {"pool": {"size": 10}}}}}}',
    )
    extraction = parser.extract_tool_calls("".join(pieces), None)

    assert extraction.tools_called is True
    assert len(extraction.tool_calls) == 2

    # Flat call comes first.
    simple_fn = extraction.tool_calls[0].function
    assert simple_fn.name == "simpleTool"
    simple_args = json.loads(simple_fn.arguments)
    assert simple_args["value"] == "test"

    # Nested call comes second; its structure must round-trip through JSON.
    complex_fn = extraction.tool_calls[1].function
    assert complex_fn.name == "complexTool"
    complex_args = json.loads(complex_fn.arguments)
    assert complex_args["config"]["database"]["connection"]["pool"]["size"] == 10
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_quotes_and_brackets_in_string(parser):
    """Braces and brackets inside string values are not treated as structure."""
    import json

    # Test with quotes and brackets inside quoted string values
    pieces = (
        '{"name": "searchTool", ',
        '"parameters": {',
        '"query": "test {value} [complex]",',
        '"nested": {"inner": "more {brackets}"}',
        "}}",
    )
    extraction = parser.extract_tool_calls("".join(pieces), None)

    assert extraction.tools_called is True
    assert len(extraction.tool_calls) == 1
    only_call = extraction.tool_calls[0].function
    assert only_call.name == "searchTool"

    # The string payloads must come back verbatim, brackets included.
    decoded = json.loads(only_call.arguments)
    assert decoded["query"] == "test {value} [complex]"
    assert decoded["nested"]["inner"] == "more {brackets}"
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_escaped_quotes_in_nested_json(parser):
    """Escaped double quotes inside a string value are preserved."""
    import json

    # Test with escaped quotes in deeply nested JSON
    raw_output = (
        '{"name": "parserTool", "parameters": {"text": "He said \\"Hello {world}\\""}}'
    )
    extraction = parser.extract_tool_calls(raw_output, None)

    assert extraction.tools_called is True
    assert len(extraction.tool_calls) == 1
    only_call = extraction.tool_calls[0].function
    assert only_call.name == "parserTool"

    # After decoding, the escapes collapse to literal quote characters.
    decoded = json.loads(only_call.arguments)
    assert decoded["text"] == 'He said "Hello {world}"'
|
||||
|
||||
|
||||
def test_extract_tool_calls_missing_name_key(parser):
    """A JSON object without a "name" key is returned as plain content."""
    # Test that missing "name" key returns content
    nameless_json = '{"parameters": {}}'
    extraction = parser.extract_tool_calls(nameless_json, None)

    assert extraction.tools_called is False
    assert len(extraction.tool_calls) == 0
    assert extraction.content == nameless_json
|
||||
|
||||
|
||||
def test_extract_tool_calls_missing_parameters_and_arguments_key(parser):
    """A JSON object with neither "parameters" nor "arguments" is plain content."""
    # Test that missing both "parameters" and "arguments" keys returns content
    name_only_json = '{"name": "toolWithoutParams"}'
    extraction = parser.extract_tool_calls(name_only_json, None)

    assert extraction.tools_called is False
    assert len(extraction.tool_calls) == 0
    assert extraction.content == name_only_json
|
||||
|
||||
|
||||
def test_regex_timeout_handling(parser):
    """A TimeoutError raised by the start regex falls back to plain content."""
    suspicious_text = "{hello world[A(A=" + "\t)A(A=,\t" * 2

    # Stand-in pattern whose finditer always raises TimeoutError.
    timing_out_regex = MagicMock()
    timing_out_regex.finditer.side_effect = TimeoutError("Regex timeout")

    with patch.object(parser, "tool_call_start_regex", timing_out_regex):
        outcome = parser.extract_tool_calls(suspicious_text, None)

    # The input must come back untouched, with no tool calls reported.
    assert outcome.content == suspicious_text
    assert outcome.tools_called is False
    assert len(outcome.tool_calls) == 0
    timing_out_regex.finditer.assert_called_once()
|
||||
@@ -0,0 +1,269 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
# Llama4-specific variants of the pythonic parser fixtures. Each *_OUTPUT is
# raw model text and the matching *_CALL is the FunctionCall it is compared to.
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="get_weather", arguments='{"city": "LA", "metric": "C"}'
)

_MORE_TYPES_KWARGS = (
    "name='Doe'",
    "age=9",
    "address={'city': 'LA', 'state': 'CA'}",
    "role=None",
    "passed_test=True",
    "aliases=['John', 'Johnny']",
)
MORE_TYPES_FUNCTION_OUTPUT = "[register_user(" + ", ".join(_MORE_TYPES_KWARGS) + ")]"
_MORE_TYPES_JSON_FIELDS = (
    '"name": "Doe"',
    '"age": 9',
    '"address": {"city": "LA", "state": "CA"}',
    '"role": null',
    '"passed_test": true',
    '"aliases": ["John", "Johnny"]',
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
    name="register_user",
    arguments="{" + ", ".join(_MORE_TYPES_JSON_FIELDS) + "}",
)

PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(name="get_weather", arguments="{}")

EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
    name="do_something_cool", arguments='{"additional_data": {}}'
)

EMPTY_LIST_FUNCTION_OUTPUT = "[do_something_cool(steps=[])]"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
    name="do_something_cool", arguments='{"steps": []}'
)

# Raw string: the backslashes are part of the simulated model text.
ESCAPED_STRING_FUNCTION_OUTPUT = (
    r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]"
)
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)

# The simple call wrapped in the <|python_start|> / <|python_end|> tags.
PYTHON_TAG_FUNCTION_OUTPUT = (
    "<|python_start|>" + SIMPLE_FUNCTION_OUTPUT + "<|python_end|>"
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text is returned as content and produces no tool calls."""
    parser_factory = ToolParserManager.get_tool_parser("llama4_pythonic")
    tool_parser: ToolParser = parser_factory(default_tokenizer)
    plain_reply = "How can I help you today?"

    content, tool_calls = run_tool_extraction(
        tool_parser, plain_reply, streaming=streaming
    )

    assert content == plain_reply
    assert len(tool_calls) == 0
|
||||
|
||||
|
||||
# Two parallel calls behind the <|python_start|> tag, with no closing tag and
# no space between the calls.
test_str = "<|python_start|>"
test_str += "[get_weather(city='LA', metric='C'),"
test_str += "register_user(name='Doe', age=9)]"


def _register_doe_call() -> FunctionCall:
    """Return a fresh expected call for ``register_user(name='Doe', age=9)``."""
    return FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}')


# (streaming, model_output, expected_tool_calls) triples for test_tool_call.
TEST_CASES = [
    # Fixed: this case previously fed ESCAPED_STRING_FUNCTION_OUTPUT, which made
    # it an exact duplicate of "escaped_string_streaming" and left the simple
    # payload untested in streaming mode.
    pytest.param(
        True,
        SIMPLE_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        id="simple_streaming",
    ),
    pytest.param(
        False, SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], id="simple_nonstreaming"
    ),
    pytest.param(
        True,
        MORE_TYPES_FUNCTION_OUTPUT,
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_streaming",
    ),
    pytest.param(
        False,
        MORE_TYPES_FUNCTION_OUTPUT,
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_nonstreaming",
    ),
    pytest.param(
        True,
        PARAMETERLESS_FUNCTION_OUTPUT,
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_streaming",
    ),
    pytest.param(
        False,
        PARAMETERLESS_FUNCTION_OUTPUT,
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_nonstreaming",
    ),
    pytest.param(
        True,
        EMPTY_DICT_FUNCTION_OUTPUT,
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_streaming",
    ),
    pytest.param(
        False,
        EMPTY_DICT_FUNCTION_OUTPUT,
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_nonstreaming",
    ),
    pytest.param(
        True,
        EMPTY_LIST_FUNCTION_OUTPUT,
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_streaming",
    ),
    pytest.param(
        False,
        EMPTY_LIST_FUNCTION_OUTPUT,
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_nonstreaming",
    ),
    pytest.param(
        True,
        ESCAPED_STRING_FUNCTION_OUTPUT,
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_streaming",
    ),
    pytest.param(
        False,
        ESCAPED_STRING_FUNCTION_OUTPUT,
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_nonstreaming",
    ),
    pytest.param(
        True,
        "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
        [SIMPLE_FUNCTION_CALL, _register_doe_call()],
        id="parallel_calls_streaming",
    ),
    pytest.param(
        False,
        "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
        [SIMPLE_FUNCTION_CALL, _register_doe_call()],
        id="parallel_calls_nonstreaming",
    ),
    pytest.param(
        True,
        PYTHON_TAG_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        id="python_tag_streaming",
    ),
    pytest.param(
        False,
        PYTHON_TAG_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        id="python_tag_nonstreaming",
    ),
    # Fixed: the two tagged parallel cases reused the ids "parallel_calls_*"
    # from the untagged cases above; they now have distinct ids.
    pytest.param(
        True,
        test_str,
        [SIMPLE_FUNCTION_CALL, _register_doe_call()],
        id="python_tag_parallel_calls_streaming",
    ),
    pytest.param(
        False,
        "<|python_start|>[get_weather(city='LA', metric='C'), "
        + "register_user(name='Doe', age=9)]",
        [SIMPLE_FUNCTION_CALL, _register_doe_call()],
        id="python_tag_parallel_calls_nonstreaming",
    ),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    default_tokenizer: TokenizerLike,
):
    """Each TEST_CASES payload parses into exactly the expected function calls."""
    parser_factory = ToolParserManager.get_tool_parser("llama4_pythonic")
    tool_parser: ToolParser = parser_factory(default_tokenizer)

    # Only the extracted calls are checked here; the content is not asserted.
    _content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=streaming
    )

    assert len(tool_calls) == len(expected_tool_calls)
    for position, expected in enumerate(expected_tool_calls):
        actual = tool_calls[position]
        assert actual.type == "function"
        assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """Three tool calls delivered in a single streaming delta are all parsed."""
    parser_factory = ToolParserManager.get_tool_parser("llama4_pythonic")
    tool_parser: ToolParser = parser_factory(default_tokenizer)

    # The adjacent literals concatenate, so the whole response is ONE delta.
    single_delta = (
        "<|python_start|>[get_weather(city='LA', metric='C'), "
        "get_weather(), "
        "do_something_cool(steps=[])]<|python_end|>"
    )

    reconstructor = run_tool_extraction_streaming(
        tool_parser, [single_delta], assert_one_tool_per_delta=False
    )

    assert reconstructor.other_content == ""
    assert len(reconstructor.tool_calls) == 3
    expected_functions = (
        SIMPLE_FUNCTION_CALL,
        PARAMETERLESS_FUNCTION_CALL,
        EMPTY_LIST_FUNCTION_CALL,
    )
    for position, expected in enumerate(expected_functions):
        assert reconstructor.tool_calls[position].function == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
    """A TimeoutError from TOOL_CALL_REGEX.match falls back to plain content."""
    parser_factory = ToolParserManager.get_tool_parser("llama4_pythonic")
    tool_parser: ToolParser = parser_factory(default_tokenizer)

    suspicious_text = "hello world[A(A=" + "\t)A(A=,\t" * 2

    # Stand-in pattern whose match always raises TimeoutError.
    timing_out_regex = MagicMock()
    timing_out_regex.match.side_effect = TimeoutError("Regex timeout")

    with patch.object(tool_parser, "TOOL_CALL_REGEX", timing_out_regex):
        content, tool_calls = run_tool_extraction(
            tool_parser, suspicious_text, streaming=streaming
        )

        # The input must come back as ordinary text with no tool calls.
        assert content == suspicious_text
        assert len(tool_calls) == 0
        timing_out_regex.match.assert_called_once()
|
||||
251
tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py
Normal file
251
tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py
Normal file
@@ -0,0 +1,251 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
# Call syntax reference:
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
# Each *_OUTPUT is raw model text (without the <function_calls> wrapper) and
# the matching *_CALL is the FunctionCall it is compared to.
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="get_weather", arguments='{"city": "San Francisco", "metric": "celsius"}'
)

# Shared shape of the register_user call; only the null/boolean spellings vary.
_REGISTER_USER_TEMPLATE = (
    "register_user(name='John Doe', age=37, "
    "address={{'city': 'San Francisco', 'state': 'CA'}}, "
    "role={null_literal}, passed_test={true_literal}, "
    "aliases=['John', 'Johnny'])"
)
MORE_TYPES_FUNCTION_OUTPUT = _REGISTER_USER_TEMPLATE.format(
    null_literal="None", true_literal="True"
)
MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS = _REGISTER_USER_TEMPLATE.format(
    null_literal="null", true_literal="true"
)
_MORE_TYPES_JSON_FIELDS = (
    '"name": "John Doe"',
    '"age": 37',
    '"address": {"city": "San Francisco", "state": "CA"}',
    '"role": null',
    '"passed_test": true',
    '"aliases": ["John", "Johnny"]',
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
    name="register_user",
    arguments="{" + ", ".join(_MORE_TYPES_JSON_FIELDS) + "}",
)

PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(name="get_weather", arguments="{}")

EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
    name="do_something_cool", arguments='{"additional_data": {}}'
)

EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
    name="do_something_cool", arguments='{"steps": []}'
)

# Raw string: the backslashes are part of the simulated model text.
ESCAPED_STRING_FUNCTION_OUTPUT = (
    r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
)
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text is returned as content and produces no tool calls."""
    parser_factory = ToolParserManager.get_tool_parser("olmo3")
    tool_parser: ToolParser = parser_factory(default_tokenizer)
    plain_reply = "How can I help you today?"

    content, tool_calls = run_tool_extraction(
        tool_parser, plain_reply, streaming=streaming
    )

    assert content == plain_reply
    assert len(tool_calls) == 0
|
||||
|
||||
|
||||
def _wrap_function_calls(body: str) -> str:
    """Surround *body* with the <function_calls> tags used in these fixtures."""
    return f"<function_calls>{body}</function_calls>"


# (id stem, id suffix, unwrapped model text, expected calls) per scenario.
_SCENARIOS = [
    ("simple", "", SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL]),
    ("more_types", "", MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL]),
    (
        "more_types",
        "_json_literals",
        MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS,
        [MORE_TYPES_FUNCTION_CALL],
    ),
    ("parameterless", "", PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL]),
    ("empty_dict", "", EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL]),
    ("empty_list", "", EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL]),
    (
        "escaped_string",
        "",
        ESCAPED_STRING_FUNCTION_OUTPUT,
        [ESCAPED_STRING_FUNCTION_CALL],
    ),
    (
        "parallel_calls",
        "",
        f"{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}",
        [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
    ),
]

# Every scenario is expanded into a streaming case followed by a
# non-streaming case, each with its own copy of the expected-call list.
TEST_CASES = [
    pytest.param(
        is_streaming,
        _wrap_function_calls(raw_output),
        list(expected_calls),
        id=f"{stem}_{mode}{suffix}",
    )
    for stem, suffix, raw_output, expected_calls in _SCENARIOS
    for is_streaming, mode in ((True, "streaming"), (False, "nonstreaming"))
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    default_tokenizer: TokenizerLike,
):
    """Each TEST_CASES payload parses into the expected calls with no content."""
    parser_factory = ToolParserManager.get_tool_parser("olmo3")
    tool_parser: ToolParser = parser_factory(default_tokenizer)

    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=streaming
    )

    # Nothing outside the tool calls is expected to be returned as content.
    assert content is None
    assert len(tool_calls) == len(expected_tool_calls)
    for position, expected in enumerate(expected_tool_calls):
        actual = tool_calls[position]
        assert actual.type == "function"
        assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """Three tool calls split over two large streaming deltas are all parsed."""
    parser_factory = ToolParserManager.get_tool_parser("olmo3")
    tool_parser: ToolParser = parser_factory(default_tokenizer)

    # The first delta stops in the middle of a string argument.
    opening_delta = "<function_calls>get_weather(city='San"
    # The second delta finishes the first call and carries the other two.
    closing_delta = (
        " Francisco', metric='celsius')\n"
        f"{PARAMETERLESS_FUNCTION_OUTPUT}\n"
        f"{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>"
    )

    reconstructor = run_tool_extraction_streaming(
        tool_parser,
        [opening_delta, closing_delta],
        assert_one_tool_per_delta=False,
    )

    assert reconstructor.other_content == ""
    assert len(reconstructor.tool_calls) == 3
    expected_functions = (
        SIMPLE_FUNCTION_CALL,
        PARAMETERLESS_FUNCTION_CALL,
        EMPTY_LIST_FUNCTION_CALL,
    )
    for position, expected in enumerate(expected_functions):
        assert reconstructor.tool_calls[position].function == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
    """A TimeoutError from TOOL_CALL_REGEX.match falls back to plain content."""
    parser_factory = ToolParserManager.get_tool_parser("olmo3")
    tool_parser: ToolParser = parser_factory(default_tokenizer)

    suspicious_text = "hello world[A(A=" + "\t)A(A=,\t" * 2

    # Stand-in pattern whose match always raises TimeoutError.
    timing_out_regex = MagicMock()
    timing_out_regex.match.side_effect = TimeoutError("Regex timeout")

    with patch.object(tool_parser, "TOOL_CALL_REGEX", timing_out_regex):
        content, tool_calls = run_tool_extraction(
            tool_parser, suspicious_text, streaming=streaming
        )

        # The input must come back as ordinary text with no tool calls.
        assert content == suspicious_text
        assert len(tool_calls) == 0
        timing_out_regex.match.assert_called_once()
|
||||
359
tests/entrypoints/openai/tool_parsers/test_openai_tool_parser.py
Normal file
359
tests/entrypoints/openai/tool_parsers/test_openai_tool_parser.py
Normal file
@@ -0,0 +1,359 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import jsonschema
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from rapidfuzz import fuzz
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "openai/gpt-oss-20b"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Yield a RemoteOpenAIServer for MODEL_NAME, shared across the module.

    The server is launched with auto tool choice enabled and the "openai"
    tool-call parser selected.
    """
    context_flags = ["--max-model-len", "8192", "--enforce-eager"]
    tool_flags = ["--enable-auto-tool-choice", "--tool-call-parser", "openai"]
    launch_args = context_flags + tool_flags

    with RemoteOpenAIServer(MODEL_NAME, launch_args) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the ``server`` fixture.

    The client is opened through ``server.get_async_client()`` and closed by
    the async context manager once the test using it finishes.
    """
    client_cm = server.get_async_client()
    async with client_cm as async_client:
        yield async_client
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# Tool Definitions
|
||||
# ==========================================================
|
||||
def _single_string_param_tool(
    name: str, description: str, param_name: str, param_description: str
) -> dict:
    """Build a function-tool definition with one required string parameter."""
    schema = {
        "type": "object",
        "properties": {
            param_name: {
                "type": "string",
                "description": param_description,
            }
        },
        "required": [param_name],
    }
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": schema,
        },
    }


# Tool definitions passed as ``tools=`` in every request of this module.
TOOLS = [
    _single_string_param_tool(
        "calculator",
        "Performs basic arithmetic calculations.",
        "expression",
        "Arithmetic expression to evaluate, e.g. '123 + 456'.",
    ),
    _single_string_param_tool(
        "get_time",
        "Retrieves the current local time for a given city.",
        "city",
        "City name, e.g. 'New York'.",
    ),
]
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# Message Examples
|
||||
# ==========================================================
|
||||
# Single-turn prompt that should trigger the calculator tool.
MESSAGES_CALC = [
    {"role": "user", "content": "Calculate 123 + 456 using the calculator."}
]

# Single-turn prompt that should trigger the get_time tool.
MESSAGES_GET_TIME = [
    {"role": "user", "content": "What is the current time in New York?"}
]

# System + user prompt asking for both tools in one response.
MESSAGES_MULTIPLE_CALLS = [
    {
        "role": "system",
        "content": (
            "You can call multiple tools. "
            # Fixed: the trailing space was missing here, so the implicit
            # concatenation produced "tool_calls arraycontaining".
            "When using more than one, return single JSON object with tool_calls array "
            "containing each tool call with its function name and arguments. "
            "Do not output multiple JSON objects separately."
        ),
    },
    {
        "role": "user",
        "content": "First, calculate 7 * 8 using the calculator. "
        "Then, use get_time to tell me the current time in New York.",
    },
]

# Prompt that explicitly asks for no calculation to be performed.
MESSAGES_INVALID_CALL = [
    {
        "role": "user",
        "content": "Can you help with something, "
        "but don’t actually perform any calculation?",
    }
]


# Expected outputs
FUNC_CALC = "calculator"
FUNC_ARGS_CALC = '{"expression":"123 + 456"}'

FUNC_TIME = "get_time"
FUNC_ARGS_TIME = '{"city": "New York"}'
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# Utility to extract reasoning and tool calls
|
||||
# ==========================================================
|
||||
def extract_reasoning_and_calls(chunks: list) -> tuple[str, list[str], list[str]]:
    """
    Accumulate reasoning text and tool-call fragments from streaming chunks.

    Args:
        chunks: Streamed chunk objects. Each one is read through
            ``chunk.choices[0].delta``; chunks with an empty ``choices``
            list or without a delta are ignored.

    Returns:
        A ``(reasoning_content, arguments, function_names)`` tuple.
        ``reasoning_content`` is the concatenation of every delta's
        ``reasoning_content``. ``arguments`` and ``function_names`` are
        parallel lists ordered by tool-call ``index``: argument fragments
        are concatenated per index, and the last non-empty name seen for
        an index wins.
    """
    reasoning_parts: list[str] = []
    tool_calls: dict[int, dict[str, str]] = {}

    for chunk in chunks:
        # A chunk may carry no choices at all; indexing [0] would raise.
        if not chunk.choices:
            continue

        delta = getattr(chunk.choices[0], "delta", None)
        if not delta:
            continue

        reasoning_piece = getattr(delta, "reasoning_content", None)
        if reasoning_piece:
            reasoning_parts.append(reasoning_piece)

        # ``tool_calls`` may be absent or None on a delta; treat both as empty.
        for tc in getattr(delta, "tool_calls", []) or []:
            idx = getattr(tc, "index", 0)
            tool_entry = tool_calls.setdefault(idx, {"name": "", "arguments": ""})

            func = getattr(tc, "function", None)
            if not func:
                continue
            if getattr(func, "name", None):
                tool_entry["name"] = func.name
            if getattr(func, "arguments", None):
                tool_entry["arguments"] += func.arguments

    # Sort once by tool-call index and derive both parallel lists from it.
    ordered_entries = [tool_calls[idx] for idx in sorted(tool_calls)]
    function_names: list[str] = [entry["name"] for entry in ordered_entries]
    arguments: list[str] = [entry["arguments"] for entry in ordered_entries]

    return "".join(reasoning_parts), arguments, function_names
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# Test Scenarios
|
||||
# ==========================================================
|
||||
@pytest.mark.asyncio
async def test_calculator_tool_call_and_argument_accuracy(client: openai.AsyncOpenAI):
    """Check that the calculator tool is called with the requested expression."""
    completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_CALC,
        tools=TOOLS,
        temperature=0.0,
        stream=False,
    )

    reply = completion.choices[0].message
    tool_calls = getattr(reply, "tool_calls", [])
    assert tool_calls, "No tool calls detected"

    # Pick the first call that targets the calculator, if there is one.
    calc_call = None
    for candidate in tool_calls:
        if candidate.function.name == FUNC_CALC:
            calc_call = candidate
            break
    assert calc_call, "Calculator function not called"

    raw_args = calc_call.function.arguments
    assert raw_args, "Calculator arguments missing"
    assert "123" in raw_args and "456" in raw_args, (
        f"Expected values not in raw arguments: {raw_args}"
    )

    try:
        parsed_args = json.loads(raw_args)
    except json.JSONDecodeError:
        pytest.fail(f"Invalid JSON in calculator arguments: {raw_args}")

    # Fuzzy comparison tolerates small formatting differences in the expression.
    expected_expr = "123 + 456"
    actual_expr = parsed_args.get("expression", "")
    similarity = fuzz.ratio(actual_expr, expected_expr)

    assert similarity > 90, (
        f"Expression mismatch: expected '{expected_expr}' "
        f"got '{actual_expr}' (similarity={similarity}%)"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_tool_call_get_time_with_reasoning(client: openai.AsyncOpenAI):
    """Check the streamed get_time call and the reasoning that accompanies it."""
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_GET_TIME,
        tools=TOOLS,
        temperature=0.0,
        stream=True,
    )

    # Drain the stream, then fold the chunks into text and tool-call parts.
    collected = []
    async for chunk in stream:
        collected.append(chunk)
    reasoning, arguments, function_names = extract_reasoning_and_calls(collected)

    assert FUNC_TIME in function_names, "get_time function not called"

    assert any("New York" in arg for arg in arguments), (
        f"Expected get_time arguments for New York not found in {arguments}"
    )

    assert len(reasoning) > 0, "Expected reasoning content missing"

    relevant_keywords = ["New York", "time", "current"]
    assert any(keyword in reasoning for keyword in relevant_keywords), (
        f"Reasoning is not relevant to the request: {reasoning}"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_multiple_tools(client: openai.AsyncOpenAI):
    """Test streamed multi-tool response with reasoning.

    The expectations are best-effort. Previously a failed assertion was
    caught and only printed, so the test always reported a pass. A failure is
    now reported as an expected failure (xfail), which keeps it visible in
    the test report without failing the run.
    """
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_MULTIPLE_CALLS,
        tools=TOOLS,
        temperature=0.0,
        stream=True,
    )

    chunks = [chunk async for chunk in stream]
    reasoning, arguments, function_names = extract_reasoning_and_calls(chunks)

    try:
        assert FUNC_CALC in function_names, (
            f"Calculator tool missing — found {function_names}"
        )
        assert FUNC_TIME in function_names, (
            f"Time tool missing — found {function_names}"
        )
        assert len(reasoning) > 0, "Expected reasoning content in streamed response"
    except AssertionError as e:
        # Surface the unmet expectation instead of silently passing.
        pytest.xfail(f"Multi-tool streaming expectations not met: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_tool_call(client: openai.AsyncOpenAI):
|
||||
"""
|
||||
Verify that ambiguous instructions that should not trigger a tool
|
||||
do not produce any tool calls.
|
||||
"""
|
||||
response = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=MESSAGES_INVALID_CALL,
|
||||
tools=TOOLS,
|
||||
temperature=0.0,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
message = response.choices[0].message
|
||||
|
||||
assert message is not None, "Expected message in response"
|
||||
assert hasattr(message, "content"), "Expected 'content' field in message"
|
||||
|
||||
tool_calls = getattr(message, "tool_calls", [])
|
||||
assert not tool_calls, (
|
||||
f"Model unexpectedly attempted a tool call on invalid input: {tool_calls}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_with_temperature(client: openai.AsyncOpenAI):
|
||||
"""
|
||||
Verify model produces valid tool or text output
|
||||
under non-deterministic sampling.
|
||||
"""
|
||||
response = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=MESSAGES_CALC,
|
||||
tools=TOOLS,
|
||||
temperature=0.7,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
message = response.choices[0].message
|
||||
assert message is not None, "Expected non-empty message in response"
|
||||
assert message.tool_calls or message.content, (
|
||||
"Response missing both text and tool calls"
|
||||
)
|
||||
|
||||
print(f"\nTool calls: {message.tool_calls}")
|
||||
print(f"Text: {message.content}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_response_schema_accuracy(client: openai.AsyncOpenAI):
|
||||
"""Validate that tool call arguments adhere to their declared JSON schema."""
|
||||
response = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=MESSAGES_MULTIPLE_CALLS,
|
||||
tools=TOOLS,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
calls = response.choices[0].message.tool_calls
|
||||
assert calls, "No tool calls produced"
|
||||
|
||||
for call in calls:
|
||||
func_name = call.function.name
|
||||
args = json.loads(call.function.arguments)
|
||||
|
||||
schema: dict[str, object] | None = None
|
||||
for tool_entry in TOOLS:
|
||||
function_def = tool_entry.get("function")
|
||||
if (
|
||||
function_def
|
||||
and isinstance(function_def, dict)
|
||||
and function_def.get("name") == func_name
|
||||
):
|
||||
schema = function_def.get("parameters")
|
||||
break
|
||||
|
||||
assert schema is not None, f"No matching tool schema found for {func_name}"
|
||||
|
||||
jsonschema.validate(instance=args, schema=schema)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semantic_consistency_with_temperature(client: openai.AsyncOpenAI):
|
||||
"""Test that temperature variation doesn't cause contradictory reasoning."""
|
||||
responses = []
|
||||
for temp in [0.0, 0.5, 1.0]:
|
||||
resp = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=MESSAGES_CALC,
|
||||
tools=TOOLS,
|
||||
temperature=temp,
|
||||
)
|
||||
text = (resp.choices[0].message.content or "").strip()
|
||||
responses.append(text)
|
||||
|
||||
# Compare fuzzy similarity between low- and mid-temperature outputs
|
||||
low_mid_sim = fuzz.ratio(responses[0], responses[1])
|
||||
assert low_mid_sim > 60, (
|
||||
f"Semantic drift too large between T=0.0 and T=0.5 ({low_mid_sim}%)"
|
||||
)
|
||||
@@ -0,0 +1,231 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
# Fixtures for the "pythonic" tool parser tests. Each *_OUTPUT is a single
# call written in Python call syntax (tests wrap it in a list literal), and the
# matching *_CALL is the FunctionCall the parser is expected to produce.
# Format reference:
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1

# Two plain string keyword arguments.
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "San Francisco", "metric": "celsius"}',
)

# One argument per Python literal kind: str, int, dict, None, bool, list.
MORE_TYPES_FUNCTION_OUTPUT = (
    "register_user("
    + ", ".join(
        (
            "name='John Doe'",
            "age=37",
            "address={'city': 'San Francisco', 'state': 'CA'}",
            "role=None",
            "passed_test=True",
            "aliases=['John', 'Johnny']",
        )
    )
    + ")"
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
    name="register_user",
    arguments="{"
    + ", ".join(
        (
            '"name": "John Doe"',
            '"age": 37',
            '"address": {"city": "San Francisco", "state": "CA"}',
            '"role": null',
            '"passed_test": true',
            '"aliases": ["John", "Johnny"]',
        )
    )
    + "}",
)

# A call with no arguments at all.
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(name="get_weather", arguments="{}")

# Empty container values.
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
    name="do_something_cool", arguments='{"additional_data": {}}'
)
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
    name="do_something_cool", arguments='{"steps": []}'
)

# Escaped quotes inside string arguments (raw string keeps the backslashes).
ESCAPED_STRING_FUNCTION_OUTPUT = (
    r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
)
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text must pass through the pythonic parser as content only."""
    parser_cls = ToolParserManager.get_tool_parser("pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)
    model_output = "How can I help you today?"

    content, tool_calls = run_tool_extraction(
        tool_parser,
        model_output,
        streaming=streaming,
    )

    # The text comes back untouched and no tool calls are reported.
    assert content == model_output
    assert len(tool_calls) == 0
|
||||
|
||||
|
||||
# Parametrization table for test_tool_call. Every scenario is emitted twice,
# streaming first and non-streaming second, as
# (streaming, model_output, expected_tool_calls) with id
# "<scenario>_streaming" / "<scenario>_nonstreaming".
TEST_CASES = [
    pytest.param(
        is_streaming,
        f"[{call_text}]",
        list(expected_calls),
        id=f"{label}_{'streaming' if is_streaming else 'nonstreaming'}",
    )
    for label, call_text, expected_calls in (
        ("simple", SIMPLE_FUNCTION_OUTPUT, (SIMPLE_FUNCTION_CALL,)),
        ("more_types", MORE_TYPES_FUNCTION_OUTPUT, (MORE_TYPES_FUNCTION_CALL,)),
        (
            "parameterless",
            PARAMETERLESS_FUNCTION_OUTPUT,
            (PARAMETERLESS_FUNCTION_CALL,),
        ),
        ("empty_dict", EMPTY_DICT_FUNCTION_OUTPUT, (EMPTY_DICT_FUNCTION_CALL,)),
        ("empty_list", EMPTY_LIST_FUNCTION_OUTPUT, (EMPTY_LIST_FUNCTION_CALL,)),
        (
            "escaped_string",
            ESCAPED_STRING_FUNCTION_OUTPUT,
            (ESCAPED_STRING_FUNCTION_CALL,),
        ),
        (
            # Two calls in one list literal.
            "parallel_calls",
            f"{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}",
            (SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL),
        ),
    )
    for is_streaming in (True, False)
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    default_tokenizer: TokenizerLike,
):
    """Parse ``model_output`` and compare the calls with the expected ones.

    The whole output is expected to be consumed as tool calls, so no content
    may be left over.
    """
    parser_cls = ToolParserManager.get_tool_parser("pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    content, tool_calls = run_tool_extraction(
        tool_parser,
        model_output,
        streaming=streaming,
    )

    assert content is None
    assert len(tool_calls) == len(expected_tool_calls)
    # Lengths match (asserted above), so walk both lists by position.
    for position, expected in enumerate(expected_tool_calls):
        actual = tool_calls[position]
        assert actual.type == "function"
        assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """Stream three tool calls in only two coarse deltas.

    The second delta finishes the first call and carries two more complete
    calls, so the one-tool-per-delta assertion is switched off.
    """
    parser_cls = ToolParserManager.get_tool_parser("pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    first_delta = "[get_weather(city='San"
    second_delta = (
        " Francisco', metric='celsius'), "
        f"{PARAMETERLESS_FUNCTION_OUTPUT}, {EMPTY_LIST_FUNCTION_OUTPUT}]"
    )
    model_output_deltas = [first_delta, second_delta]

    reconstructor = run_tool_extraction_streaming(
        tool_parser,
        model_output_deltas,
        assert_one_tool_per_delta=False,
    )

    assert reconstructor.other_content == ""
    assert len(reconstructor.tool_calls) == 3
    first_call, second_call, third_call = reconstructor.tool_calls
    assert first_call.function == SIMPLE_FUNCTION_CALL
    assert second_call.function == PARAMETERLESS_FUNCTION_CALL
    assert third_call.function == EMPTY_LIST_FUNCTION_CALL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
    """A TimeoutError from the tool-call regex must degrade to plain text."""
    parser_cls = ToolParserManager.get_tool_parser("pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    repeated_tail = "\t)A(A=,\t" * 2
    fake_problematic_input = "hello world[A(A=" + repeated_tail

    # Stand-in regex whose match() always raises TimeoutError.
    timeout_regex = MagicMock()
    timeout_regex.match.side_effect = TimeoutError("Regex timeout")

    with patch.object(tool_parser, "TOOL_CALL_REGEX", timeout_regex):
        content, tool_calls = run_tool_extraction(
            tool_parser,
            fake_problematic_input,
            streaming=streaming,
        )

        # The input is returned as regular content, with no tool calls, and
        # the regex was consulted exactly once.
        assert content == fake_problematic_input
        assert len(tool_calls) == 0
        timeout_regex.match.assert_called_once()
|
||||
167
tests/entrypoints/openai/tool_parsers/utils.py
Normal file
167
tests/entrypoints/openai/tool_parsers/utils.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser
|
||||
|
||||
|
||||
class StreamingToolReconstructor:
    """Rebuild full tool calls and text content from streamed delta messages.

    Deltas are fed in one at a time through ``append_delta``. Text content is
    concatenated into ``other_content`` and tool-call fragments are merged into
    ``tool_calls``. Assertions validate the stream as it is consumed.
    """

    def __init__(self, assert_one_tool_per_delta: bool = True):
        """Create an empty reconstructor.

        Args:
            assert_one_tool_per_delta: When true, every delta may carry at
                most one tool-call fragment.
        """
        # Tool calls reconstructed so far, in index order.
        self.tool_calls: list[ToolCall] = []
        # Concatenation of all non-tool text content seen so far.
        self.other_content: str = ""
        self._assert_one_tool_per_delta = assert_one_tool_per_delta

    def append_delta(self, delta: DeltaMessage):
        """Merge one streamed delta into the accumulated state.

        Args:
            delta: The next delta; it must carry content, tool calls, or both.

        Raises:
            AssertionError: If the delta breaks a streaming invariant: neither
                content nor tool calls, more than one tool call when that is
                disallowed, a non-function call type, a repeated id or name,
                a missing id or name on first appearance, or an out-of-order
                index.
        """
        if delta.content is not None:
            self.other_content += delta.content
        else:
            assert delta.tool_calls, (
                "Streaming results should have either content or tool calls (or both)"
            )
        if self._assert_one_tool_per_delta:
            # Note: This isn't strictly required by the API and may not be
            # possible to adhere to depending on the token space and number of
            # tokens per streamed response from the model, but it is required
            # by tool_use tests, so we enforce it here by default also.
            assert len(delta.tool_calls) < 2, (
                "Streaming should include only one tool call per update."
            )
        for call_delta in delta.tool_calls:
            assert call_delta.type is None or call_delta.type == "function", (
                "Streaming tool calls should only emit function calls. Got "
                f"{call_delta.type}"
            )
            # An index inside the current list continues an existing call;
            # anything else starts a new one.
            current_tool_call = (
                self.tool_calls[call_delta.index]
                if call_delta.index < len(self.tool_calls)
                else None
            )
            if current_tool_call is not None:
                assert not call_delta.function.name, (
                    "Streaming tool calls should emit the full function name "
                    f"exactly once. Got {call_delta.function.name}"
                )
                assert not call_delta.id, (
                    "Streaming tool calls must emit function id only once. Got "
                    f"{call_delta.id}"
                )
                # Only the most recently opened call may receive more arguments.
                assert call_delta.index == len(self.tool_calls) - 1, (
                    f"Incorrect index for tool delta. Got {call_delta.index}, "
                    f"expected {len(self.tool_calls) - 1}"
                )
                # A fragment with no arguments (None) adds nothing; treat it as
                # empty, matching the first-appearance branch below.
                current_tool_call.function.arguments += (
                    call_delta.function.arguments or ""
                )
            else:
                assert call_delta.id is not None, (
                    "Streaming tool calls must have an id on first appearance"
                )
                assert call_delta.function.name is not None, (
                    "Streaming tool calls must have a function name on first appearance"
                )
                assert call_delta.index == len(self.tool_calls), (
                    f"Incorrect index for tool delta. Got {call_delta.index}, "
                    f"expected {len(self.tool_calls)}"
                )
                self.tool_calls.append(
                    ToolCall(
                        id=call_delta.id,
                        function=FunctionCall(
                            name=call_delta.function.name,
                            arguments=call_delta.function.arguments or "",
                        ),
                    )
                )
|
||||
|
||||
|
||||
def run_tool_extraction(
    tool_parser: ToolParser,
    model_output: str,
    request: ChatCompletionRequest | None = None,
    streaming: bool = False,
    assert_one_tool_per_delta: bool = True,
) -> tuple[str | None, list[ToolCall]]:
    """Run ``tool_parser`` over ``model_output`` and return content and calls.

    Args:
        tool_parser: Parser under test.
        model_output: Complete model output text.
        request: Optional request forwarded to the parser.
        streaming: Select the streaming or the non-streaming code path.
        assert_one_tool_per_delta: Forwarded to the streaming path only.

    Returns:
        ``(content, tool_calls)``. In streaming mode, empty accumulated
        content is reported as ``None``.
    """
    if not streaming:
        extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, request)
        # The tools_called flag has to agree with whether calls were returned.
        assert extracted.tools_called == bool(extracted.tool_calls)
        return extracted.content, extracted.tool_calls

    reconstructor = run_tool_extraction_streaming(
        tool_parser,
        model_output,
        request,
        assert_one_tool_per_delta=assert_one_tool_per_delta,
    )
    streamed_content = reconstructor.other_content
    return (streamed_content if streamed_content else None), reconstructor.tool_calls
|
||||
|
||||
|
||||
def run_tool_extraction_nonstreaming(
    tool_parser: ToolParser,
    model_output: str,
    request: ChatCompletionRequest | None = None,
) -> ExtractedToolCallInformation:
    """Extract tool calls from complete output in a single parser call.

    When no request is supplied, a minimal one (no messages, model
    "test-model") is created for the parser.
    """
    if not request:
        request = ChatCompletionRequest(messages=[], model="test-model")
    extraction = tool_parser.extract_tool_calls(model_output, request)
    return extraction
|
||||
|
||||
|
||||
def split_string_into_token_deltas(tokenizer: TokenizerLike, text: str) -> list[str]:
    """Split ``text`` into a series of string deltas using ``tokenizer``.

    Each delta is normally the string equivalent of a single token. When a
    character is split across several tokens, decoding an incomplete prefix
    ends in U+FFFD; no delta is emitted for such a prefix, and the character
    is emitted as one delta once it decodes completely.

    Args:
        tokenizer: Object providing ``encode(text, add_special_tokens=False)``
            and ``decode(token_ids)``.
        text: The string to split.

    Returns:
        The deltas in order. For text that round-trips through the tokenizer,
        their concatenation equals the decoded text.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    emitted_text = ""
    current_text = ""
    deltas = []
    for end in range(1, len(token_ids) + 1):
        # Decode the whole prefix each time: a token's text can depend on the
        # tokens before it, so decoding tokens one by one is not equivalent.
        current_text = tokenizer.decode(token_ids[:end])
        if current_text.endswith("\ufffd"):
            # Probably an incomplete multi-byte character; wait for more tokens.
            continue
        deltas.append(current_text[len(emitted_text) :])
        emitted_text = current_text
    if len(current_text) > len(emitted_text):
        # The text really does end in U+FFFD (or never completed): flush it.
        deltas.append(current_text[len(emitted_text) :])
    return deltas
|
||||
|
||||
|
||||
def run_tool_extraction_streaming(
    tool_parser: ToolParser,
    model_deltas: Iterable[str],
    request: ChatCompletionRequest | None = None,
    assert_one_tool_per_delta: bool = True,
) -> StreamingToolReconstructor:
    """Feed text deltas to the parser's streaming API and collect the result.

    Args:
        tool_parser: Parser under test.
        model_deltas: Text deltas in order. A plain string is first split into
            per-token deltas with the parser's tokenizer.
        request: Optional request; a minimal one is created when omitted.
        assert_one_tool_per_delta: Forwarded to the reconstructor.

    Returns:
        The reconstructor holding the accumulated content and tool calls.
    """
    if isinstance(model_deltas, str):
        model_deltas = split_string_into_token_deltas(
            tool_parser.model_tokenizer, model_deltas
        )

    if not request:
        request = ChatCompletionRequest(messages=[], model="test-model")
    reconstructor = StreamingToolReconstructor(
        assert_one_tool_per_delta=assert_one_tool_per_delta
    )

    text_so_far = ""
    tokens_so_far: list[int] = []
    for text_delta in model_deltas:
        # Map the delta's tokens to ids, dropping tokens missing from the
        # parser's vocabulary.
        new_token_ids = []
        for piece in tool_parser.model_tokenizer.tokenize(text_delta):
            if piece in tool_parser.vocab:
                new_token_ids.append(tool_parser.vocab.get(piece))

        extended_text = text_so_far + text_delta
        extended_tokens = tokens_so_far + new_token_ids
        delta_message = tool_parser.extract_tool_calls_streaming(
            text_so_far,
            extended_text,
            text_delta,
            tokens_so_far,
            extended_tokens,
            new_token_ids,
            request,
        )
        # The parser may return None for a delta; only real messages are merged.
        if delta_message is not None:
            reconstructor.append_delta(delta_message)

        text_so_far = extended_text
        tokens_so_far = extended_tokens
    return reconstructor
|
||||
Reference in New Issue
Block a user