Sync from v0.13
This commit is contained in:
0
tests/tool_use/__init__.py
Normal file
0
tests/tool_use/__init__.py
Normal file
68
tests/tool_use/conftest.py
Normal file
68
tests/tool_use/conftest.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import ARGS, CONFIGS, ServerConfig
|
||||
|
||||
|
||||
# select models to test based on command line arguments
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--models", nargs="+", help="Specify one or more models to test")
|
||||
parser.addoption(
|
||||
"--extended",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="invoke extended tests requiring large GPUs",
|
||||
)
|
||||
|
||||
|
||||
# for each server config, download the model and return the config
|
||||
@pytest.fixture(scope="session", params=CONFIGS.keys())
|
||||
def server_config(request):
|
||||
extended = request.config.getoption("--extended")
|
||||
models = request.config.getoption("--models")
|
||||
|
||||
config_keys_to_test = [
|
||||
key
|
||||
for key in CONFIGS
|
||||
if (models is None or key in models)
|
||||
and (extended or not CONFIGS[key].get("extended", False))
|
||||
]
|
||||
|
||||
config_key = request.param
|
||||
if config_key not in config_keys_to_test:
|
||||
pytest.skip(f"Skipping config '{config_key}'")
|
||||
|
||||
config = CONFIGS[config_key]
|
||||
|
||||
if current_platform.is_rocm() and not config.get("supports_rocm", True):
|
||||
pytest.skip(
|
||||
"The {} model can't be tested on the ROCm platform".format(config["model"])
|
||||
)
|
||||
|
||||
# download model and tokenizer using transformers
|
||||
snapshot_download(config["model"])
|
||||
yield CONFIGS[request.param]
|
||||
|
||||
|
||||
# run this for each server config
|
||||
@pytest.fixture(scope="session")
|
||||
def server(request, server_config: ServerConfig):
|
||||
model = server_config["model"]
|
||||
args_for_model = server_config["arguments"]
|
||||
with RemoteOpenAIServer(
|
||||
model, ARGS + args_for_model, max_wait_seconds=480
|
||||
) as server:
|
||||
yield server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server: RemoteOpenAIServer):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
0
tests/tool_use/mistral/__init__.py
Normal file
0
tests/tool_use/mistral/__init__.py
Normal file
43
tests/tool_use/mistral/conftest.py
Normal file
43
tests/tool_use/mistral/conftest.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import ARGS, CONFIGS, ServerConfig
|
||||
|
||||
|
||||
# for each server config, download the model and return the config
|
||||
@pytest.fixture(scope="package", params=CONFIGS.keys())
|
||||
def server_config(request):
|
||||
config = CONFIGS[request.param]
|
||||
|
||||
if current_platform.is_rocm() and not config.get("supports_rocm", True):
|
||||
pytest.skip(
|
||||
"The {} model can't be tested on the ROCm platform".format(config["model"])
|
||||
)
|
||||
|
||||
# download model and tokenizer using transformers
|
||||
snapshot_download(config["model"])
|
||||
yield CONFIGS[request.param]
|
||||
|
||||
|
||||
# run this for each server config
|
||||
@pytest.fixture(scope="package")
|
||||
def server(request, server_config: ServerConfig):
|
||||
model = server_config["model"]
|
||||
args_for_model = server_config["arguments"]
|
||||
with RemoteOpenAIServer(
|
||||
model, ARGS + args_for_model, max_wait_seconds=480
|
||||
) as server:
|
||||
yield server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server: RemoteOpenAIServer):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
30
tests/tool_use/mistral/test_mistral_tool_calls.py
Normal file
30
tests/tool_use/mistral/test_mistral_tool_calls.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL
|
||||
|
||||
|
||||
# test: a tool_choice with mistral-tokenizer results in an ID of length 9
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_ASKING_FOR_TOOLS,
|
||||
temperature=0,
|
||||
max_completion_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL],
|
||||
tool_choice=WEATHER_TOOL,
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
|
||||
assert choice.finish_reason != "tool_calls" # "stop" or "length"
|
||||
assert choice.message.role == "assistant"
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 1
|
||||
assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral
|
||||
32
tests/tool_use/mistral/utils.py
Normal file
32
tests/tool_use/mistral/utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class ServerConfig(TypedDict, total=False):
|
||||
model: str
|
||||
arguments: list[str]
|
||||
system_prompt: str | None
|
||||
supports_parallel: bool | None
|
||||
supports_rocm: bool | None
|
||||
|
||||
|
||||
ARGS: list[str] = ["--max-model-len", "1024"]
|
||||
|
||||
CONFIGS: dict[str, ServerConfig] = {
|
||||
"mistral": {
|
||||
"model": "mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"arguments": [
|
||||
"--tokenizer-mode",
|
||||
"mistral",
|
||||
'--ignore-patterns="consolidated.safetensors"',
|
||||
],
|
||||
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
|
||||
" that you have would be helpful to answer a user query, "
|
||||
"call the tool. Otherwise, answer the user's query directly "
|
||||
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
|
||||
"to the user's question - just respond to it normally.",
|
||||
},
|
||||
}
|
||||
63
tests/tool_use/test_chat_completion_request_validations.py
Normal file
63
tests/tool_use/test_chat_completion_request_validations.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
|
||||
|
||||
def test_chat_completion_request_with_no_tools():
|
||||
# tools key is not present
|
||||
request = ChatCompletionRequest.model_validate(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "facebook/opt-125m",
|
||||
}
|
||||
)
|
||||
assert request.tool_choice == "none"
|
||||
|
||||
# tools key is None
|
||||
request = ChatCompletionRequest.model_validate(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "facebook/opt-125m",
|
||||
"tools": None,
|
||||
}
|
||||
)
|
||||
assert request.tool_choice == "none"
|
||||
|
||||
# tools key present but empty
|
||||
request = ChatCompletionRequest.model_validate(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "facebook/opt-125m",
|
||||
"tools": [],
|
||||
}
|
||||
)
|
||||
assert request.tool_choice == "none"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool_choice", ["auto", "required"])
|
||||
def test_chat_completion_request_with_tool_choice_but_no_tools(tool_choice):
|
||||
with pytest.raises(
|
||||
ValueError, match="When using `tool_choice`, `tools` must be set."
|
||||
):
|
||||
ChatCompletionRequest.model_validate(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "facebook/opt-125m",
|
||||
"tool_choice": tool_choice,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="When using `tool_choice`, `tools` must be set."
|
||||
):
|
||||
ChatCompletionRequest.model_validate(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "facebook/opt-125m",
|
||||
"tool_choice": tool_choice,
|
||||
"tools": None,
|
||||
}
|
||||
)
|
||||
153
tests/tool_use/test_chat_completions.py
Normal file
153
tests/tool_use/test_chat_completions.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from .utils import (
|
||||
MESSAGES_WITHOUT_TOOLS,
|
||||
WEATHER_TOOL,
|
||||
ServerConfig,
|
||||
ensure_system_prompt,
|
||||
)
|
||||
|
||||
|
||||
# test: make sure chat completions without tools provided work even when tools
|
||||
# are enabled. This makes sure tool call chat templates work, AND that the tool
|
||||
# parser stream processing doesn't change the output of the model.
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_without_tools(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config),
|
||||
temperature=0,
|
||||
max_completion_tokens=150,
|
||||
model=model_name,
|
||||
logprobs=False,
|
||||
)
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
output_text = chat_completion.choices[0].message.content
|
||||
|
||||
# check to make sure we got text
|
||||
assert output_text is not None
|
||||
assert len(output_text) > 0
|
||||
assert stop_reason != "tool_calls"
|
||||
|
||||
# check to make sure no tool calls were returned
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config),
|
||||
temperature=0,
|
||||
max_completion_tokens=150,
|
||||
model=model_name,
|
||||
logprobs=False,
|
||||
stream=True,
|
||||
)
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
role_sent: bool = False
|
||||
|
||||
# assemble streamed chunks
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# make sure the role is assistant
|
||||
if delta.role:
|
||||
assert not role_sent
|
||||
assert delta.role == "assistant"
|
||||
role_sent = True
|
||||
|
||||
if delta.content:
|
||||
chunks.append(delta.content)
|
||||
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == choice.finish_reason
|
||||
|
||||
# make sure tool call chunks aren't being streamed
|
||||
assert not delta.tool_calls or len(delta.tool_calls) == 0
|
||||
|
||||
# make sure the role was sent, only 1 finish reason was sent, that chunks
|
||||
# were in fact sent, and that the chunks match non-streaming
|
||||
assert role_sent
|
||||
assert finish_reason_count == 1
|
||||
assert len(chunks)
|
||||
assert "".join(chunks) == output_text
|
||||
|
||||
|
||||
# test: conversation with tools enabled and provided that should not invoke
|
||||
# tools, to make sure we can still get normal chat completion responses
|
||||
# and that they won't be parsed as tools
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tools(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config),
|
||||
temperature=0,
|
||||
max_completion_tokens=150,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL],
|
||||
logprobs=False,
|
||||
)
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
output_text = chat_completion.choices[0].message.content
|
||||
|
||||
# check to make sure we got text
|
||||
assert output_text is not None
|
||||
assert stop_reason != "tool_calls"
|
||||
assert len(output_text) > 0
|
||||
|
||||
# check to make sure no tool calls were returned
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config),
|
||||
temperature=0,
|
||||
max_completion_tokens=150,
|
||||
model=model_name,
|
||||
logprobs=False,
|
||||
tools=[WEATHER_TOOL],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
role_sent: bool = False
|
||||
|
||||
# assemble streamed chunks
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# make sure the role is assistant
|
||||
if delta.role:
|
||||
assert delta.role == "assistant"
|
||||
role_sent = True
|
||||
|
||||
if delta.content:
|
||||
chunks.append(delta.content)
|
||||
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
|
||||
# make sure tool call chunks aren't being streamed
|
||||
assert not delta.tool_calls or len(delta.tool_calls) == 0
|
||||
|
||||
# make sure the role was sent, only 1 finish reason was sent, that chunks
|
||||
# were in fact sent, and that the chunks match non-streaming
|
||||
assert role_sent
|
||||
assert finish_reason_count == 1
|
||||
assert chunk.choices[0].finish_reason == stop_reason
|
||||
assert chunk.choices[0].finish_reason != "tool_calls"
|
||||
assert len(chunks)
|
||||
assert "".join(chunks) == output_text
|
||||
271
tests/tool_use/test_parallel_tool_calls.py
Normal file
271
tests/tool_use/test_parallel_tool_calls.py
Normal file
@@ -0,0 +1,271 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from .utils import (
|
||||
MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
|
||||
SEARCH_TOOL,
|
||||
WEATHER_TOOL,
|
||||
ServerConfig,
|
||||
)
|
||||
|
||||
|
||||
# test: getting the model to generate parallel tool calls (streaming/not)
|
||||
# when requested. NOTE that not all models may support this, so some exclusions
|
||||
# may be added in the future. e.g. llama 3.1 models are not designed to support
|
||||
# parallel tool calls.
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
if not server_config.get("supports_parallel", True):
|
||||
pytest.skip(
|
||||
"The {} model doesn't support parallel tool calls".format(
|
||||
server_config["model"]
|
||||
)
|
||||
)
|
||||
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
|
||||
|
||||
# make sure 2 tool calls are present
|
||||
assert choice.message.role == "assistant"
|
||||
assert non_streamed_tool_calls is not None
|
||||
assert len(non_streamed_tool_calls) == 2
|
||||
|
||||
for tool_call in non_streamed_tool_calls:
|
||||
# make sure the tool includes a function and ID
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function is not None
|
||||
assert isinstance(tool_call.id, str)
|
||||
assert len(tool_call.id) >= 9
|
||||
|
||||
# make sure the weather tool was called correctly
|
||||
assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
|
||||
assert isinstance(tool_call.function.arguments, str)
|
||||
|
||||
parsed_arguments = json.loads(tool_call.function.arguments)
|
||||
assert isinstance(parsed_arguments, dict)
|
||||
assert isinstance(parsed_arguments.get("city"), str)
|
||||
assert isinstance(parsed_arguments.get("state"), str)
|
||||
|
||||
assert stop_reason == "tool_calls"
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
role_name: str | None = None
|
||||
finish_reason_count: int = 0
|
||||
|
||||
tool_call_names: list[str] = []
|
||||
tool_call_args: list[str] = []
|
||||
tool_call_idx: int = -1
|
||||
tool_call_id_count: int = 0
|
||||
|
||||
async for chunk in stream:
|
||||
# if there's a finish reason make sure it's tools
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
# if a role is being streamed make sure it wasn't already set to
|
||||
# something else
|
||||
if chunk.choices[0].delta.role:
|
||||
assert not role_name or role_name == "assistant"
|
||||
role_name = "assistant"
|
||||
|
||||
# if a tool call is streamed make sure there's exactly one
|
||||
# (based on the request parameters
|
||||
streamed_tool_calls = chunk.choices[0].delta.tool_calls
|
||||
|
||||
if streamed_tool_calls and len(streamed_tool_calls) > 0:
|
||||
# make sure only one diff is present - correct even for parallel
|
||||
assert len(streamed_tool_calls) == 1
|
||||
tool_call = streamed_tool_calls[0]
|
||||
|
||||
# if a new tool is being called, set up empty arguments
|
||||
if tool_call.index != tool_call_idx:
|
||||
tool_call_idx = tool_call.index
|
||||
tool_call_args.append("")
|
||||
|
||||
# if a tool call ID is streamed, make sure one hasn't been already
|
||||
if tool_call.id:
|
||||
tool_call_id_count += 1
|
||||
assert isinstance(tool_call.id, str) and (len(tool_call.id) >= 9)
|
||||
|
||||
# if parts of the function start being streamed
|
||||
if tool_call.function:
|
||||
# if the function name is defined, set it. it should be streamed
|
||||
# IN ENTIRETY, exactly one time.
|
||||
if tool_call.function.name:
|
||||
assert isinstance(tool_call.function.name, str)
|
||||
tool_call_names.append(tool_call.function.name)
|
||||
|
||||
if tool_call.function.arguments:
|
||||
# make sure they're a string and then add them to the list
|
||||
assert isinstance(tool_call.function.arguments, str)
|
||||
|
||||
tool_call_args[tool_call.index] += tool_call.function.arguments
|
||||
|
||||
assert finish_reason_count == 1
|
||||
assert role_name == "assistant"
|
||||
|
||||
assert len(non_streamed_tool_calls) == len(tool_call_names) == len(tool_call_args)
|
||||
|
||||
for i in range(2):
|
||||
assert non_streamed_tool_calls[i].function.name == tool_call_names[i]
|
||||
streamed_args = json.loads(tool_call_args[i])
|
||||
non_streamed_args = json.loads(non_streamed_tool_calls[i].function.arguments)
|
||||
assert streamed_args == non_streamed_args
|
||||
|
||||
|
||||
# test: providing parallel tool calls back to the model to get a response
|
||||
# (streaming/not)
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_with_results(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
if not server_config.get("supports_parallel", True):
|
||||
pytest.skip(
|
||||
"The {} model doesn't support parallel tool calls".format(
|
||||
server_config["model"]
|
||||
)
|
||||
)
|
||||
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
|
||||
assert choice.finish_reason != "tool_calls" # "stop" or "length"
|
||||
assert choice.message.role == "assistant"
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.content is not None
|
||||
assert "98" in choice.message.content # Dallas temp in tool response
|
||||
assert "78" in choice.message.content # Orlando temp in tool response
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
role_sent: bool = False
|
||||
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
if delta.role:
|
||||
assert not role_sent
|
||||
assert delta.role == "assistant"
|
||||
role_sent = True
|
||||
|
||||
if delta.content:
|
||||
chunks.append(delta.content)
|
||||
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == choice.finish_reason
|
||||
|
||||
assert not delta.tool_calls or len(delta.tool_calls) == 0
|
||||
|
||||
assert role_sent
|
||||
assert finish_reason_count == 1
|
||||
assert len(chunks)
|
||||
assert "".join(chunks) == choice.message.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
|
||||
"""
|
||||
Ensure only one tool call is returned when parallel_tool_calls is False.
|
||||
"""
|
||||
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
parallel_tool_calls=False,
|
||||
)
|
||||
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
|
||||
|
||||
# make sure only 1 tool call is present
|
||||
assert len(non_streamed_tool_calls) == 1
|
||||
assert stop_reason == "tool_calls"
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
parallel_tool_calls=False,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
finish_reason_count: int = 0
|
||||
tool_call_id_count: int = 0
|
||||
|
||||
async for chunk in stream:
|
||||
# if there's a finish reason make sure it's tools
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
streamed_tool_calls = chunk.choices[0].delta.tool_calls
|
||||
if streamed_tool_calls and len(streamed_tool_calls) > 0:
|
||||
tool_call = streamed_tool_calls[0]
|
||||
if tool_call.id:
|
||||
tool_call_id_count += 1
|
||||
|
||||
# make sure only 1 streaming tool call is present
|
||||
assert tool_call_id_count == 1
|
||||
assert finish_reason_count == 1
|
||||
201
tests/tool_use/test_tool_calls.py
Normal file
201
tests/tool_use/test_tool_calls.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from .utils import (
|
||||
MESSAGES_ASKING_FOR_TOOLS,
|
||||
MESSAGES_WITH_TOOL_RESPONSE,
|
||||
SEARCH_TOOL,
|
||||
WEATHER_TOOL,
|
||||
)
|
||||
|
||||
|
||||
# test: request a chat completion that should return tool calls, so we know they
|
||||
# are parsable
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_ASKING_FOR_TOOLS,
|
||||
temperature=0,
|
||||
max_completion_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
tool_calls = chat_completion.choices[0].message.tool_calls
|
||||
|
||||
# make sure a tool call is present
|
||||
assert choice.message.role == "assistant"
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].type == "function"
|
||||
assert tool_calls[0].function is not None
|
||||
assert isinstance(tool_calls[0].id, str)
|
||||
assert len(tool_calls[0].id) >= 9
|
||||
|
||||
# make sure the weather tool was called (classic example) with arguments
|
||||
assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
|
||||
assert tool_calls[0].function.arguments is not None
|
||||
assert isinstance(tool_calls[0].function.arguments, str)
|
||||
|
||||
# make sure the arguments parse properly
|
||||
parsed_arguments = json.loads(tool_calls[0].function.arguments)
|
||||
assert isinstance(parsed_arguments, dict)
|
||||
assert isinstance(parsed_arguments.get("city"), str)
|
||||
assert isinstance(parsed_arguments.get("state"), str)
|
||||
assert parsed_arguments.get("city") == "Dallas"
|
||||
assert parsed_arguments.get("state") == "TX"
|
||||
|
||||
assert stop_reason == "tool_calls"
|
||||
|
||||
function_name: str | None = None
|
||||
function_args_str: str = ""
|
||||
tool_call_id: str | None = None
|
||||
role_name: str | None = None
|
||||
finish_reason_count: int = 0
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=MESSAGES_ASKING_FOR_TOOLS,
|
||||
temperature=0,
|
||||
max_completion_tokens=100,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
assert chunk.choices[0].index == 0
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
# if a role is being streamed make sure it wasn't already set to
|
||||
# something else
|
||||
if chunk.choices[0].delta.role:
|
||||
assert not role_name or role_name == "assistant"
|
||||
role_name = "assistant"
|
||||
|
||||
# if a tool call is streamed make sure there's exactly one
|
||||
# (based on the request parameters
|
||||
streamed_tool_calls = chunk.choices[0].delta.tool_calls
|
||||
|
||||
if streamed_tool_calls and len(streamed_tool_calls) > 0:
|
||||
assert len(streamed_tool_calls) == 1
|
||||
tool_call = streamed_tool_calls[0]
|
||||
|
||||
# if a tool call ID is streamed, make sure one hasn't been already
|
||||
if tool_call.id:
|
||||
assert not tool_call_id
|
||||
tool_call_id = tool_call.id
|
||||
|
||||
# if parts of the function start being streamed
|
||||
if tool_call.function:
|
||||
# if the function name is defined, set it. it should be streamed
|
||||
# IN ENTIRETY, exactly one time.
|
||||
if tool_call.function.name:
|
||||
assert function_name is None
|
||||
assert isinstance(tool_call.function.name, str)
|
||||
function_name = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
assert isinstance(tool_call.function.arguments, str)
|
||||
function_args_str += tool_call.function.arguments
|
||||
|
||||
assert finish_reason_count == 1
|
||||
assert role_name == "assistant"
|
||||
assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9)
|
||||
|
||||
# validate the name and arguments
|
||||
assert function_name == WEATHER_TOOL["function"]["name"]
|
||||
assert function_name == tool_calls[0].function.name
|
||||
assert isinstance(function_args_str, str)
|
||||
|
||||
# validate arguments
|
||||
streamed_args = json.loads(function_args_str)
|
||||
assert isinstance(streamed_args, dict)
|
||||
assert isinstance(streamed_args.get("city"), str)
|
||||
assert isinstance(streamed_args.get("state"), str)
|
||||
assert streamed_args.get("city") == "Dallas"
|
||||
assert streamed_args.get("state") == "TX"
|
||||
|
||||
# make sure everything matches non-streaming except for ID
|
||||
assert function_name == tool_calls[0].function.name
|
||||
assert choice.message.role == role_name
|
||||
assert choice.message.tool_calls[0].function.name == function_name
|
||||
|
||||
# compare streamed with non-streamed args dict-wise, not string-wise
|
||||
# because character-to-character comparison might not work e.g. the tool
|
||||
# call parser adding extra spaces or something like that. we care about the
|
||||
# dicts matching not byte-wise match
|
||||
assert parsed_arguments == streamed_args
|
||||
|
||||
|
||||
# test: providing tools and results back to model to get a non-tool response
|
||||
# (streaming/not)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_with_results(client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITH_TOOL_RESPONSE,
|
||||
temperature=0,
|
||||
max_completion_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
|
||||
assert choice.finish_reason != "tool_calls" # "stop" or "length"
|
||||
assert choice.message.role == "assistant"
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.content is not None
|
||||
assert "98" in choice.message.content # the temperature from the response
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITH_TOOL_RESPONSE,
|
||||
temperature=0,
|
||||
max_completion_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
role_sent: bool = False
|
||||
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
if delta.role:
|
||||
assert not role_sent
|
||||
assert delta.role == "assistant"
|
||||
role_sent = True
|
||||
|
||||
if delta.content:
|
||||
chunks.append(delta.content)
|
||||
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == choice.finish_reason
|
||||
|
||||
assert not delta.tool_calls or len(delta.tool_calls) == 0
|
||||
|
||||
assert role_sent
|
||||
assert finish_reason_count == 1
|
||||
assert len(chunks)
|
||||
assert "".join(chunks) == choice.message.content
|
||||
330
tests/tool_use/test_tool_choice_required.py
Normal file
330
tests/tool_use/test_tool_choice_required.py
Normal file
@@ -0,0 +1,330 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import regex as re
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionToolsParam,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.tool_parsers.utils import get_json_schema_from_tools
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
# Two strict function tools used to exercise tool_choice="required" schema
# generation: one with a single required parameter, one with two.
EXAMPLE_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for"
                        ", e.g. 'San Francisco'",
                    },
                },
                "required": ["city"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
    {
        "type": "function",
        "function": {
            "name": "get_forecast",
            "description": "Get the weather forecast for a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to get the forecast for, e.g. "
                        "'New York'",
                    },
                    "days": {
                        "type": "integer",
                        "description": "Number of days to get the forecast for (1-7)",
                    },
                },
                "required": ["city", "days"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
]
|
||||
|
||||
|
||||
def _compile_and_check(
    tools: list[ChatCompletionToolsParam], sample_output, should_match: bool
):
    """Compile the tool_choice="required" schema for *tools* to a regex and
    assert whether the JSON-serialized *sample_output* fully matches it.

    Args:
        tools: tool definitions from which the JSON schema is derived.
        sample_output: candidate tool-call payload (any JSON-serializable value).
        should_match: expected outcome of the full-match check.
    """
    schema = get_json_schema_from_tools(tools=tools, tool_choice="required")
    assert isinstance(schema, dict)

    # use build_regex_from_schema used in JSONLogitsProcessor to create Guide;
    # imported lazily so merely collecting this module doesn't need outlines_core
    from outlines_core.json_schema import build_regex_from_schema

    regex = build_regex_from_schema(json.dumps(schema))
    compiled = re.compile(regex)
    matches = compiled.fullmatch(json.dumps(sample_output)) is not None

    assert matches == should_match
|
||||
|
||||
|
||||
# (sample_output, should_match) pairs that conform to EXAMPLE_TOOLS' schema;
# every entry here must be accepted by the compiled regex.
VALID_TOOL_OUTPUTS = [
    ([{"name": "get_current_weather", "parameters": {"city": "Vienna"}}], True),
    (
        [
            {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
            {"name": "get_current_weather", "parameters": {"city": "Berlin"}},
        ],
        True,
    ),
    ([{"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}}], True),
    (
        [
            {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
            {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
        ],
        True,
    ),
    (
        [
            {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
            {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
            {"name": "get_forecast", "parameters": {"city": "Berlin", "days": 7}},
            {"name": "get_current_weather", "parameters": {"city": "Berlin"}},
        ],
        True,
    ),
]

# Just the tool-call lists themselves, reused as inputs for the streaming test.
VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sample_output, should_match",
    VALID_TOOL_OUTPUTS
    + [
        (None, False),
        ([], False),  # empty list cannot be generated
        ({}, False),  # empty object cannot be generated
        ([{}], False),  # list with empty object cannot be generated
        (
            [
                {  # function without required parameters cannot be generated
                    "name": "get_current_weather"
                }
            ],
            False,
        ),
        (
            [
                {  # function without required parameters cannot be generated
                    "name": "get_current_weather",
                    "parameters": {},
                }
            ],
            False,
        ),
        (
            [
                {  # function without required parameters cannot be generated
                    "name": "get_current_weather",
                    "parameters": None,
                }
            ],
            False,
        ),
        (
            {  # tool call without lists cannot be generated
                "name": "get_current_weather",
                "parameters": {"city": "Vienna"},
            },
            False,
        ),
        (
            [
                {  # tool call with extra parameters cannot be generated
                    "name": "get_current_weather",
                    "parameters": {"city": "Vienna", "extra": "value"},
                }
            ],
            False,
        ),
        (
            [
                {  # tool call where parameters are first cannot be generated
                    "parameters": {"city": "Vienna"},
                    "name": "get_current_weather",
                }
            ],
            False,
        ),
        (
            [
                {  # tool call without all required parameters cannot be generated
                    "name": "get_forecast",
                    "parameters": {"city": "Vienna"},
                }
            ],
            False,
        ),
        (  # tool call with incorrect name/parameters cannot be generated
            [{"name": "get_weather", "parameters": {"city": "Vienna", "days": 7}}],
            False,
        ),
        (  # tool call with both valid and empty function cannot be generated
            [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}, {}],
            False,
        ),
    ],
)
def test_structured_outputs_json(sample_output, should_match):
    """The regex compiled from EXAMPLE_TOOLS' required-tool schema must accept
    exactly the well-formed tool-call lists and reject every malformed variant
    above.
    """
    _compile_and_check(
        tools=TypeAdapter(list[ChatCompletionToolsParam]).validate_python(
            EXAMPLE_TOOLS
        ),
        sample_output=sample_output,
        should_match=should_match,
    )
|
||||
|
||||
|
||||
def update_parameters_none(tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
    """Mutate *tool* in place so its function has no parameter schema (None)."""
    fn = tool.function
    fn.parameters = None
    return tool
|
||||
|
||||
|
||||
def update_parameters_empty_dict(
    tool: ChatCompletionToolsParam,
) -> ChatCompletionToolsParam:
    """Mutate *tool* in place so its function's parameter schema is an empty dict."""
    fn = tool.function
    fn.parameters = {}
    return tool
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sample_output, should_match",
    [
        (None, False),
        ([], False),  # empty list cannot be generated
        ({}, False),  # empty object cannot be generated
        ([{}], False),  # list with empty object cannot be generated
        (
            [
                {  # function without required parameters cannot be generated
                    "name": "get_current_weather"
                }
            ],
            False,
        ),
        (
            [
                {  # function without required parameters cannot be generated
                    "name": "get_current_weather",
                    "parameters": None,
                }
            ],
            False,
        ),
        (
            [
                {  # function with extra parameters cannot be generated
                    "name": "get_current_weather",
                    "parameters": {"extra": "value"},
                }
            ],
            False,
        ),
        (
            [
                {  # only function with empty parameters object is valid
                    "name": "get_current_weather",
                    "parameters": {},
                }
            ],
            True,
        ),
    ],
)
@pytest.mark.parametrize(
    "update_parameters", [update_parameters_none, update_parameters_empty_dict]
)
def test_structured_outputs_json_without_parameters(
    sample_output, should_match, update_parameters
):
    """A tool whose parameter schema was cleared (to None or {}) should only
    admit a tool call carrying an empty parameters object.
    """
    updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
    tools = TypeAdapter(list[ChatCompletionToolsParam]).validate_python(updated_tools)
    tools = list(map(update_parameters, tools))
    # Sanity check: the updater really cleared every parameter schema.
    # (Generator form instead of all([...]) — no throwaway list; ruff C419.)
    assert all(
        tool.function.parameters is None or tool.function.parameters == {}
        for tool in tools
    )
    _compile_and_check(
        tools=tools, sample_output=sample_output, should_match=should_match
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("output", VALID_TOOLS)
@pytest.mark.parametrize("empty_params", [False, True])
@pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def test_streaming_output_valid(output, empty_params, delta_len):
    """Feed a serialized tool-call list to the required-tool streaming
    extractor in chunks of *delta_len* characters and verify that the emitted
    deltas reassemble to exactly the original JSON text.
    """
    # Stand-in for the OpenAIServingChat instance; the method is called
    # unbound. NOTE(review): assumes the extractor reads no real instance
    # state (MagicMock absorbs any attribute access) — confirm in serving_chat.
    self = MagicMock()

    output = deepcopy(output)
    if empty_params:
        # Degenerate variant: same function names, but empty argument objects.
        output = [{"name": o["name"], "parameters": {}} for o in output]
    output_json = json.dumps(output)

    previous_text = ""
    function_name_returned = False
    messages = []
    # Stream the JSON through the extractor delta_len characters at a time,
    # collecting every non-empty delta message it produces.
    for i in range(0, len(output_json), delta_len):
        delta_text = output_json[i : i + delta_len]
        current_text = previous_text + delta_text

        delta_message, function_name_returned = (
            OpenAIServingChat.extract_tool_call_required_streaming(
                self,
                previous_text=previous_text,
                current_text=current_text,
                delta_text=delta_text,
                function_name_returned=function_name_returned,
            )
        )

        if delta_message:
            messages.append(delta_message)

        previous_text = current_text

    assert len(messages) > 0
    # Rebuild a JSON array from the streamed deltas: a delta that carries a
    # function name opens a new object (closing the previous one if any);
    # nameless deltas contribute argument fragments.
    combined_messages = "["
    for message in messages:
        if message.tool_calls[0].function.name:
            if len(combined_messages) > 1:
                combined_messages += "},"

            combined_messages += (
                '{"name": "'
                + message.tool_calls[0].function.name
                + '", "parameters": '
                + message.tool_calls[0].function.arguments
            )
        else:
            combined_messages += message.tool_calls[0].function.arguments
    combined_messages += "}]"
    # Round-trip must match both structurally and textually.
    assert json.loads(combined_messages) == output
    assert json.dumps(json.loads(combined_messages)) == output_json
|
||||
375
tests/tool_use/utils.py
Normal file
375
tests/tool_use/utils.py
Normal file
@@ -0,0 +1,375 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from tests.utils import VLLM_PATH
|
||||
|
||||
|
||||
class ServerConfig(TypedDict, total=False):
    """Launch/behavior description of one vLLM server under test.

    All keys are optional (``total=False``); CONFIGS entries populate the
    subset that applies to each model.
    """

    # HF model id to serve
    model: str
    # extra CLI arguments appended to the universal ARGS
    arguments: list[str]
    # optional steering system prompt injected into test conversations
    system_prompt: str | None
    # whether the model/parser supports parallel tool calls
    supports_parallel: bool | None
    # whether the config can run on the ROCm platform
    supports_rocm: bool | None
    extended: bool | None  # tests do not run in CI automatically
|
||||
|
||||
|
||||
def patch_system_prompt(
    messages: list[dict[str, Any]], system_prompt: str
) -> list[dict[str, Any]]:
    """Return a copy of *messages* whose first entry is a system message with
    *system_prompt* — overwriting an existing leading system message, or
    prepending a new one. The input list is never mutated.
    """
    patched = deepcopy(messages)
    starts_with_system = patched[0]["role"] == "system"
    if starts_with_system:
        patched[0]["content"] = system_prompt
    else:
        patched.insert(0, {"role": "system", "content": system_prompt})
    return patched
|
||||
|
||||
|
||||
def ensure_system_prompt(
    messages: list[dict[str, Any]], config: ServerConfig
) -> list[dict[str, Any]]:
    """Apply the server config's system prompt to *messages*, if one is set.

    Returns the original list untouched when the config has no (or an empty)
    ``system_prompt``; otherwise returns a patched copy.
    """
    prompt = config.get("system_prompt")
    return patch_system_prompt(messages, prompt) if prompt else messages
|
||||
|
||||
|
||||
# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
# Per-model flags live in each CONFIGS entry's "arguments" list.
ARGS: list[str] = [
    "--enable-auto-tool-choice",
    "--max-model-len",
    "1024",
    "--max-num-seqs",
    "256",
]
|
||||
|
||||
# Per-model server configurations, keyed by a short test id. Each entry's
# "arguments" are appended to the universal ARGS; see ServerConfig for the
# meaning of the optional keys.
CONFIGS: dict[str, ServerConfig] = {
    "hermes": {
        "model": "NousResearch/Hermes-3-Llama-3.1-8B",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "--tool-call-parser",
            "hermes",
            "--chat-template",
            str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja"),
        ],
        "system_prompt": "You are a helpful assistant with access to tools. If a tool"
        " that you have would be helpful to answer a user query, "
        "call the tool. Otherwise, answer the user's query directly "
        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
        "to the user's question - just respond to it normally.",
    },
    "llama": {
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "--tool-call-parser",
            "llama3_json",
            "--chat-template",
            str(VLLM_PATH / "examples/tool_chat_template_llama3.1_json.jinja"),
        ],
        "supports_parallel": False,
    },
    "llama3.2": {
        "model": "meta-llama/Llama-3.2-3B-Instruct",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "--tool-call-parser",
            "llama3_json",
            "--chat-template",
            str(VLLM_PATH / "examples/tool_chat_template_llama3.2_json.jinja"),
        ],
        "supports_parallel": False,
    },
    "llama4": {
        "model": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "--tool-call-parser",
            "llama4_pythonic",
            "--chat-template",
            str(VLLM_PATH / "examples/tool_chat_template_llama4_pythonic.jinja"),
            "-tp",
            "4",
        ],
        "supports_parallel": False,
        "extended": True,
    },
    "llama4_json": {
        "model": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "-tp",
            "4",
            "--distributed-executor-backend",
            "mp",
            "--tool-call-parser",
            "llama4_json",
            "--chat-template",
            str(VLLM_PATH / "examples/tool_chat_template_llama4_json.jinja"),
        ],
        "supports_parallel": True,
        "extended": True,
    },
    "mistral-7b": {
        "model": "mistralai/Mistral-7B-Instruct-v0.3",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "--tokenizer_mode",
            "hf",
            "--load_format",
            "hf",
            "--config_format",
            "hf",
            "--tool-call-parser",
            "mistral",
            "--chat-template",
            str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"),
            '--ignore-patterns="consolidated.safetensors"',
        ],
        "system_prompt": "You are a helpful assistant with access to tools. If a tool"
        " that you have would be helpful to answer a user query, "
        "call the tool. Otherwise, answer the user's query directly "
        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
        "to the user's question - just respond to it normally.",
        "supports_parallel": True,
    },
    "mistral-small-3.2": {
        "model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "--tool-call-parser",
            "mistral",
            "--tokenizer-mode",
            "mistral",
            "--config-format",
            "mistral",
            "--load-format",
            "mistral",
            "--tensor-parallel-size",
            "4",
            '--ignore-patterns="consolidated.safetensors"',
        ],
        "system_prompt": "You are a helpful assistant with access to tools. If a tool"
        " that you have would be helpful to answer a user query, "
        "call the tool. Otherwise, answer the user's query directly "
        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
        "to the user's question - just respond to it normally.",
        "supports_parallel": True,
        "extended": True,
    },
    # FIXME: This test currently fails, need to debug why.
    # "granite20b": {
    #     "model": "mbayser/granite-20b-functioncalling-FP8-KV",
    #     "arguments": [
    #         "--tool-call-parser",
    #         "granite-20b-fc",
    #         "--chat-template",
    #         str(VLLM_PATH / "examples/tool_chat_template_granite_20b_fc.jinja"),
    #         "--max_num_seqs",
    #         "1",
    #         "--enforce-eager",
    #         "--cpu-offload-gb",
    #         "20",
    #     ],
    #     "supports_parallel": False,
    #     "supports_rocm": False,
    # },
    "granite-3.0-8b": {
        "model": "ibm-granite/granite-3.0-8b-instruct",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "--tool-call-parser",
            "granite",
            "--chat-template",
            str(VLLM_PATH / "examples/tool_chat_template_granite.jinja"),
        ],
    },
    "granite-3.1-8b": {
        "model": "ibm-granite/granite-3.1-8b-instruct",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "--tool-call-parser",
            "granite",
        ],
        "supports_parallel": True,
    },
    "internlm": {
        "model": "internlm/internlm2_5-7b-chat",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "--tool-call-parser",
            "internlm",
            "--chat-template",
            str(VLLM_PATH / "examples/tool_chat_template_internlm2_tool.jinja"),
            "--trust_remote_code",
        ],
        "supports_parallel": False,
    },
    "toolACE": {
        "model": "Team-ACE/ToolACE-8B",
        "arguments": [
            "--enforce-eager",
            "--no-enable-prefix-caching",
            "--tool-call-parser",
            "pythonic",
            "--chat-template",
            str(VLLM_PATH / "examples/tool_chat_template_toolace.jinja"),
        ],
        "supports_parallel": True,
    },
}
|
||||
|
||||
# Weather-lookup tool definition shared by the tool-use tests.
WEATHER_TOOL: ChatCompletionToolParam = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to find the weather for, "
                    "e.g. 'San Francisco'",
                },
                "state": {
                    "type": "string",
                    "description": "must the two-letter abbreviation for the state "
                    "that the city is in, e.g. 'CA' which would "
                    "mean 'California'",
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
        },
    },
}
|
||||
|
||||
# Web-search tool definition; offered alongside WEATHER_TOOL so tests can
# check that the model picks the relevant tool (or none).
SEARCH_TOOL: ChatCompletionToolParam = {
    "type": "function",
    "function": {
        "name": "web_search",
        "description": "Search the internet and get a summary of the top "
        "10 webpages. Should only be used if you don't know "
        "the answer to a user query, and the results are likely"
        "to be able to be found with a web search",
        "parameters": {
            "type": "object",
            "properties": {
                "search_term": {
                    "type": "string",
                    "description": "The term to use in the search. This should"
                    "ideally be keywords to search for, not a"
                    "natural-language question",
                }
            },
            "required": ["search_term"],
        },
    },
}
|
||||
|
||||
# Conversation where no tool is relevant — the model should answer directly.
MESSAGES_WITHOUT_TOOLS: list[ChatCompletionMessageParam] = [
    {"role": "user", "content": "Hi! How are you?"},
    {"role": "assistant", "content": "I'm doing great! How can I assist you?"},
    {"role": "user", "content": "Can you tell me a joke please?"},
]

# Single user turn that should trigger exactly one weather tool call.
MESSAGES_ASKING_FOR_TOOLS: list[ChatCompletionMessageParam] = [
    {"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"}
]

# Full round trip: user question, assistant tool call, and the tool's result
# (which reports 98 degrees — asserted on by the tests).
MESSAGES_WITH_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [
    {"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
                "type": "function",
                "function": {
                    "name": WEATHER_TOOL["function"]["name"],
                    "arguments": '{"city": "Dallas", "state": "TX", '
                    '"unit": "fahrenheit"}',
                },
            }
        ],
    },
    {
        "role": "tool",
        "tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
        "content": "The weather in Dallas is 98 degrees fahrenheit, with partly"
        "cloudy skies and a low chance of rain.",
    },
]

# User turn that should trigger two parallel weather tool calls.
MESSAGES_ASKING_FOR_PARALLEL_TOOLS: list[ChatCompletionMessageParam] = [
    {
        "role": "user",
        "content": "What is the weather in Dallas, Texas and Orlando, Florida in "
        "Fahrenheit?",
    }
]

# Parallel round trip: two tool calls in one assistant turn, each answered by
# its own tool message (matched via tool_call_id).
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [
    {
        "role": "user",
        "content": "What is the weather in Dallas, Texas and Orlando, Florida in "
        "Fahrenheit?",
    },
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
                "type": "function",
                "function": {
                    "name": WEATHER_TOOL["function"]["name"],
                    "arguments": '{"city": "Dallas", "state": "TX", '
                    '"unit": "fahrenheit"}',
                },
            },
            {
                "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
                "type": "function",
                "function": {
                    "name": WEATHER_TOOL["function"]["name"],
                    "arguments": '{"city": "Orlando", "state": "Fl", '
                    '"unit": "fahrenheit"}',
                },
            },
        ],
    },
    {
        "role": "tool",
        "tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
        "content": "The weather in Dallas TX is 98 degrees fahrenheit with mostly "
        "cloudy skies and a chance of rain in the evening.",
    },
    {
        "role": "tool",
        "tool_call_id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
        "content": "The weather in Orlando FL is 78 degrees fahrenheit with clear"
        "skies.",
    },
]
|
||||
Reference in New Issue
Block a user