feat: support pythonic tool call and index in tool call streaming (#5725)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@@ -664,6 +665,101 @@ class MultiFormatParser:
|
||||
return final_normal_text, final_calls
|
||||
|
||||
|
||||
class PythonicDetector(BaseFormatDetector):
    """
    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.

    Assumes function call format:
        [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]

    Arguments are Python literals (not JSON), so the text is parsed with
    ``ast.parse`` and keyword values are converted via ``_get_parameter_value``.
    """

    def __init__(self):
        super().__init__()
        # Matches a bracketed, comma-separated list of one or more
        # identifier(kw=value, ...) calls. DOTALL lets argument values
        # span newlines.
        self.tool_call_regex = re.compile(
            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
            re.DOTALL,
        )

    def has_tool_call(self, text: str) -> bool:
        """Return True if the stripped text starts with a pythonic tool-call list."""
        return bool(self.tool_call_regex.match(text.strip()))

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse a complete pythonic tool-call list out of ``text``.

        Returns a StreamingParseResult whose ``calls`` hold one ToolCallItem
        per parsed call (parameters serialized as a JSON object string).
        Text that is not a well-formed ``[call(...), ...]`` list is returned
        unchanged as ``normal_text`` with no calls.
        """
        text = text.strip()
        if not (text.startswith("[") and text.endswith("]")):
            # Not a pythonic tool call format.
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            module = ast.parse(text)
            parsed = getattr(module.body[0], "value", None)
            # Require a literal list whose elements are all function calls.
            if not (
                isinstance(parsed, ast.List)
                and all(isinstance(e, ast.Call) for e in parsed.elts)
            ):
                return StreamingParseResult(normal_text=text, calls=[])
            calls = []
            # Map tool name -> index in the request's tool list; unknown
            # tools get index -1 below rather than being dropped.
            tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function.name
            }
            for call in parsed.elts:
                if not isinstance(call.func, ast.Name):
                    # Skip e.g. attribute calls (obj.method(...)); only bare
                    # names can be matched against tool definitions.
                    continue
                function_name = call.func.id
                arguments = {}
                for keyword in call.keywords:
                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)
                calls.append(
                    ToolCallItem(
                        tool_index=tool_indices.get(function_name, -1),
                        name=function_name,
                        parameters=json.dumps(arguments, ensure_ascii=False),
                    )
                )
            return StreamingParseResult(normal_text="", calls=calls)
        except Exception:
            # ast.parse / literal conversion failed; surface the raw text.
            logger.exception("Error in pythonic tool call parsing.")
            return StreamingParseResult(normal_text=text, calls=[])

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for pythonic tool calls.

        Buffers input until a complete pythonic tool call (from [ to its
        matching ]) is found, then parses and emits any detected calls.
        """
        self._buffer += new_text
        start = self._buffer.find("[")
        if start == -1:
            return StreamingParseResult(normal_text="")
        # Fix: find the *matching* closing bracket by tracking nesting depth.
        # The previous str.find("]") grabbed the first "]", which truncated
        # calls with nested list arguments, e.g. [foo(x=[1, 2])].
        # NOTE(review): this does not account for "]" inside string literal
        # arguments; neither did the original scan.
        depth = 0
        end = -1
        for i in range(start, len(self._buffer)):
            ch = self._buffer[i]
            if ch == "[":
                depth += 1
            elif ch == "]":
                depth -= 1
                if depth == 0:
                    end = i
                    break
        if end != -1:
            call_text = self._buffer[start : end + 1]
            result = self.detect_and_parse(call_text, tools)
            self._buffer = self._buffer[end + 1 :]
            return result
        # Incomplete call; keep buffering.
        return StreamingParseResult(normal_text="")

    def _get_parameter_value(self, val):
        """Convert an AST literal node to a plain Python value.

        Supports constants, dicts, and lists (recursively); anything else
        (names, calls, comprehensions, ...) raises ValueError so the caller's
        except-block rejects the tool call.
        """
        if isinstance(val, ast.Constant):
            return val.value
        elif isinstance(val, ast.Dict):
            # Assumes dict keys are ast.Constant nodes (literal keys).
            return {
                k.value: self._get_parameter_value(v)
                for k, v in zip(val.keys, val.values)
            }
        elif isinstance(val, ast.List):
            return [self._get_parameter_value(v) for v in val.elts]
        else:
            raise ValueError("Tool call arguments must be literals")

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory producing the constrained-decoding structure info."""

        def info(name: str):
            return StructureInfo(begin="[", end="]", trigger="")

        return info
|
||||
|
||||
|
||||
class FunctionCallParser:
|
||||
"""
|
||||
In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
|
||||
@@ -675,6 +771,7 @@ class FunctionCallParser:
|
||||
"qwen25": Qwen25Detector,
|
||||
"mistral": MistralDetector,
|
||||
"deepseekv3": DeepSeekV3Detector,
|
||||
"pythonic": PythonicDetector,
|
||||
}
|
||||
|
||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||
|
||||
@@ -1618,6 +1618,7 @@ async def v1_chat_completions(
|
||||
|
||||
tool_call = ToolCall(
|
||||
id=str(call_item.tool_index),
|
||||
index=call_item.tool_index,
|
||||
function=FunctionResponse(
|
||||
name=call_item.name,
|
||||
arguments=call_item.parameters,
|
||||
|
||||
@@ -389,6 +389,7 @@ class ToolCall(BaseModel):
|
||||
"""Tool call response."""
|
||||
|
||||
id: str
|
||||
index: Optional[int] = None
|
||||
type: Literal["function"] = "function"
|
||||
function: FunctionResponse
|
||||
|
||||
|
||||
@@ -1107,9 +1107,9 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
type=str,
|
||||
choices=["qwen25", "mistral", "llama3", "deepseekv3"],
|
||||
choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
|
||||
default=ServerArgs.tool_call_parser,
|
||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
|
||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-hierarchical-cache",
|
||||
|
||||
Reference in New Issue
Block a user