feat: support pythonic tool call and index in tool call streaming (#5725)

This commit is contained in:
Chang Su
2025-04-29 17:30:44 -07:00
committed by GitHub
parent e4b6133b78
commit 2b06484bd1
8 changed files with 541 additions and 3 deletions

View File

@@ -1,3 +1,4 @@
import ast
import json
import logging
import re
@@ -664,6 +665,101 @@ class MultiFormatParser:
return final_normal_text, final_calls
class PythonicDetector(BaseFormatDetector):
"""
Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
Assumes function call format:
[tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
Arguments are Python literals (not JSON).
"""
def __init__(self):
super().__init__()
self.tool_call_regex = re.compile(
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
re.DOTALL,
)
def has_tool_call(self, text: str) -> bool:
return bool(self.tool_call_regex.match(text.strip()))
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
# Try parsing the text as a Python list of function calls
text = text.strip()
if not (text.startswith("[") and text.endswith("]")):
# Not a pythonic tool call format
return StreamingParseResult(normal_text=text, calls=[])
try:
module = ast.parse(text)
parsed = getattr(module.body[0], "value", None)
if not (
isinstance(parsed, ast.List)
and all(isinstance(e, ast.Call) for e in parsed.elts)
):
return StreamingParseResult(normal_text=text, calls=[])
calls = []
tool_indices = {
tool.function.name: i
for i, tool in enumerate(tools)
if tool.function.name
}
for call in parsed.elts:
if not isinstance(call.func, ast.Name):
continue
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = self._get_parameter_value(keyword.value)
calls.append(
ToolCallItem(
tool_index=tool_indices.get(function_name, -1),
name=function_name,
parameters=json.dumps(arguments, ensure_ascii=False),
)
)
return StreamingParseResult(normal_text="", calls=calls)
except Exception:
logger.exception("Error in pythonic tool call parsing.")
return StreamingParseResult(normal_text=text, calls=[])
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing for pythonic tool calls.
Buffers input until a complete pythonic tool call (from [ to ]) is found,
then parses and emits any detected calls.
"""
self._buffer += new_text
start = self._buffer.find("[")
end = self._buffer.find("]", start)
if start != -1 and end != -1:
call_text = self._buffer[start : end + 1]
result = self.detect_and_parse(call_text, tools)
self._buffer = self._buffer[end + 1 :]
return result
return StreamingParseResult(normal_text="")
def _get_parameter_value(self, val):
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
return {
k.value: self._get_parameter_value(v)
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [self._get_parameter_value(v) for v in val.elts]
else:
raise ValueError("Tool call arguments must be literals")
def structure_info(self) -> _GetInfoFunc:
def info(name: str):
return StructureInfo(begin="[", end="]", trigger="")
return info
class FunctionCallParser:
"""
In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
@@ -675,6 +771,7 @@ class FunctionCallParser:
"qwen25": Qwen25Detector,
"mistral": MistralDetector,
"deepseekv3": DeepSeekV3Detector,
"pythonic": PythonicDetector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):

View File

@@ -1618,6 +1618,7 @@ async def v1_chat_completions(
tool_call = ToolCall(
id=str(call_item.tool_index),
index=call_item.tool_index,
function=FunctionResponse(
name=call_item.name,
arguments=call_item.parameters,

View File

@@ -389,6 +389,7 @@ class ToolCall(BaseModel):
"""Tool call response."""
id: str
index: Optional[int] = None
type: Literal["function"] = "function"
function: FunctionResponse

View File

@@ -1107,9 +1107,9 @@ class ServerArgs:
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["qwen25", "mistral", "llama3", "deepseekv3"],
choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
)
parser.add_argument(
"--enable-hierarchical-cache",