Sync from v0.13
This commit is contained in:
229
vllm/tool_parsers/utils.py
Normal file
229
vllm/tool_parsers/utils.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any
|
||||
|
||||
import partial_json_parser
|
||||
from openai.types.responses import (
|
||||
FunctionTool,
|
||||
ToolChoiceFunction,
|
||||
)
|
||||
from openai.types.responses.tool import Tool
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionToolsParam,
|
||||
)
|
||||
|
||||
|
||||
def find_common_prefix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common prefix that is shared between two strings, if there is one.
|
||||
Order of arguments is NOT important.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely.
|
||||
|
||||
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
|
||||
'{"fruit": "ap'
|
||||
"""
|
||||
prefix = ""
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(0, min_length):
|
||||
if s1[i] == s2[i]:
|
||||
prefix += s1[i]
|
||||
else:
|
||||
break
|
||||
return prefix
|
||||
|
||||
|
||||
def find_common_suffix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common suffix shared between two strings, if there is one. Order of
|
||||
arguments is NOT important.
|
||||
Stops when the suffix ends OR it hits an alphanumeric character
|
||||
|
||||
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
|
||||
"""
|
||||
suffix = ""
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(1, min_length + 1):
|
||||
if s1[-i] == s2[-i] and not s1[-i].isalnum():
|
||||
suffix = s1[-i] + suffix
|
||||
else:
|
||||
break
|
||||
return suffix
|
||||
|
||||
|
||||
def extract_intermediate_diff(curr: str, old: str) -> str:
|
||||
"""
|
||||
Given two strings, extract the difference in the middle between two strings
|
||||
that are known to have a common prefix and/or suffix.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely. The order of arguments IS
|
||||
important - the new version of the partially-parsed JSON must be the first
|
||||
argument, and the secnod argument must be from the previous generation.
|
||||
|
||||
What it returns, is tokens that should be streamed to the client.
|
||||
|
||||
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
|
||||
-> 'ple'
|
||||
|
||||
"""
|
||||
suffix = find_common_suffix(curr, old)
|
||||
|
||||
old = old[::-1].replace(suffix[::-1], "", 1)[::-1]
|
||||
prefix = find_common_prefix(curr, old)
|
||||
diff = curr
|
||||
if len(suffix):
|
||||
diff = diff[::-1].replace(suffix[::-1], "", 1)[::-1]
|
||||
|
||||
if len(prefix):
|
||||
# replace the prefix only once in case it's mirrored
|
||||
diff = diff.replace(prefix, "", 1)
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
def find_all_indices(string: str, substring: str) -> list[int]:
|
||||
"""
|
||||
Find all (starting) indices of a substring in a given string. Useful for
|
||||
tool call extraction
|
||||
"""
|
||||
indices = []
|
||||
index = -1
|
||||
while True:
|
||||
index = string.find(substring, index + 1)
|
||||
if index == -1:
|
||||
break
|
||||
indices.append(index)
|
||||
return indices
|
||||
|
||||
|
||||
# partial_json_parser doesn't support extra data and
|
||||
# JSONDecoder.raw_decode doesn't support partial JSON
|
||||
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
raise
|
||||
|
||||
|
||||
def is_complete_json(input_str: str) -> bool:
|
||||
try:
|
||||
json.loads(input_str)
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
def consume_space(i: int, s: str) -> int:
|
||||
while i < len(s) and s[i].isspace():
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def _extract_tool_info(
|
||||
tool: Tool | ChatCompletionToolsParam,
|
||||
) -> tuple[str, dict[str, Any] | None]:
|
||||
if isinstance(tool, FunctionTool):
|
||||
return tool.name, tool.parameters
|
||||
elif isinstance(tool, ChatCompletionToolsParam):
|
||||
return tool.function.name, tool.function.parameters
|
||||
else:
|
||||
raise TypeError(f"Unsupported tool type: {type(tool)}")
|
||||
|
||||
|
||||
def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict:
|
||||
name, params = _extract_tool_info(tool)
|
||||
params = params if params else {"type": "object", "properties": {}}
|
||||
return {
|
||||
"properties": {
|
||||
"name": {"type": "string", "enum": [name]},
|
||||
"parameters": params,
|
||||
},
|
||||
"required": ["name", "parameters"],
|
||||
}
|
||||
|
||||
|
||||
def _get_tool_schema_defs(
|
||||
tools: list[Tool | ChatCompletionToolsParam],
|
||||
) -> dict:
|
||||
all_defs: dict[str, dict[str, Any]] = {}
|
||||
for tool in tools:
|
||||
_, params = _extract_tool_info(tool)
|
||||
if params is None:
|
||||
continue
|
||||
defs = params.pop("$defs", {})
|
||||
for def_name, def_schema in defs.items():
|
||||
if def_name in all_defs and all_defs[def_name] != def_schema:
|
||||
raise ValueError(
|
||||
f"Tool definition '{def_name}' has multiple schemas, "
|
||||
"which is not supported."
|
||||
)
|
||||
all_defs[def_name] = def_schema
|
||||
return all_defs
|
||||
|
||||
|
||||
def _get_json_schema_from_tools(
|
||||
tools: list[Tool | ChatCompletionToolsParam],
|
||||
) -> dict:
|
||||
json_schema = {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"anyOf": [_get_tool_schema_from_tool(tool) for tool in tools],
|
||||
},
|
||||
}
|
||||
json_schema_defs = _get_tool_schema_defs(tools)
|
||||
if json_schema_defs:
|
||||
json_schema["$defs"] = json_schema_defs
|
||||
return json_schema
|
||||
|
||||
|
||||
def get_json_schema_from_tools(
|
||||
tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
|
||||
tools: list[FunctionTool | ChatCompletionToolsParam] | None,
|
||||
) -> str | dict | None:
|
||||
# tool_choice: "none"
|
||||
if tool_choice in ("none", None) or tools is None:
|
||||
return None
|
||||
# tool_choice: Forced Function (Responses)
|
||||
if (not isinstance(tool_choice, str)) and isinstance(
|
||||
tool_choice, ToolChoiceFunction
|
||||
):
|
||||
tool_name = tool_choice.name
|
||||
tool_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
|
||||
if tool_name not in tool_map:
|
||||
raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
|
||||
return tool_map[tool_name].parameters
|
||||
# tool_choice: Forced Function (ChatCompletion)
|
||||
if (not isinstance(tool_choice, str)) and isinstance(
|
||||
tool_choice, ChatCompletionNamedToolChoiceParam
|
||||
):
|
||||
tool_name = tool_choice.function.name
|
||||
tool_map = {
|
||||
tool.function.name: tool
|
||||
for tool in tools
|
||||
if isinstance(tool, ChatCompletionToolsParam)
|
||||
}
|
||||
if tool_name not in tool_map:
|
||||
raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
|
||||
return tool_map[tool_name].function.parameters
|
||||
# tool_choice: "required"
|
||||
if tool_choice == "required":
|
||||
return _get_json_schema_from_tools(tools)
|
||||
# tool_choice: "auto"
|
||||
return None
|
||||
Reference in New Issue
Block a user