Fix Llama3.3 tool call support (#4320)

This commit is contained in:
Chang Su
2025-03-13 14:01:41 -07:00
committed by GitHub
parent c6d7f8d370
commit 5fe79605a8
2 changed files with 55 additions and 22 deletions

View File

@@ -318,6 +318,10 @@ class Qwen25Detector(BaseFormatDetector):
self.bot_token = "<tool_call>"
self.eot_token = "</tool_call>"
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Qwen 2.5 format tool call."""
return self.bot_token in text
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
"""
One-time parsing: Detects and parses tool calls in the provided text.
@@ -352,6 +356,10 @@ class MistralDetector(BaseFormatDetector):
self.bot_token = "[TOOL_CALLS] ["
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Mistral format tool call."""
return self.bot_token in text
def _clean_text(self, text: str) -> str:
"""
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
@@ -397,12 +405,21 @@ class Llama32Detector(BaseFormatDetector):
super().__init__()
self.bot_token = "<|python_tag|>"
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Llama 3.2 format tool call."""
# depending on the prompt format the Llama model may or may not
# prefix the output with the <|python_tag|> token
return "<|python_tag|>" in text or text.startswith("{")
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
"""Parse function calls from text, handling multiple JSON objects."""
if "<|python_tag|>" not in text:
if "<|python_tag|>" not in text and not text.startswith("{"):
return []
_, action_text = text.split("<|python_tag|>")
if "<|python_tag|>" in text:
_, action_text = text.split("<|python_tag|>")
else:
action_text = text
# Split by semicolon and process each part
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
@@ -501,6 +518,20 @@ class FunctionCallParser:
self.multi_format_parser = MultiFormatParser(detectors)
self.tools = tools
def has_tool_call(self, text: str) -> bool:
"""
Check if the given text contains a tool call in the format supported by this parser.
This delegates to the detector's implementation.
:param text: The text to check for tool calls
:return: True if the text contains a tool call, False otherwise
"""
# Check all detectors in the multi_format_parser
for detector in self.multi_format_parser.detectors:
if detector.has_tool_call(text):
return True
return False
def parse_non_stream(self, full_text: str):
"""
Non-streaming call: one-time parsing