Fix Llama3.3 tool call support (#4320)
This commit is contained in:
@@ -318,6 +318,10 @@ class Qwen25Detector(BaseFormatDetector):
|
|||||||
self.bot_token = "<tool_call>"
|
self.bot_token = "<tool_call>"
|
||||||
self.eot_token = "</tool_call>"
|
self.eot_token = "</tool_call>"
|
||||||
|
|
||||||
|
def has_tool_call(self, text: str) -> bool:
|
||||||
|
"""Check if the text contains a Qwen 2.5 format tool call."""
|
||||||
|
return self.bot_token in text
|
||||||
|
|
||||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||||
"""
|
"""
|
||||||
One-time parsing: Detects and parses tool calls in the provided text.
|
One-time parsing: Detects and parses tool calls in the provided text.
|
||||||
@@ -352,6 +356,10 @@ class MistralDetector(BaseFormatDetector):
|
|||||||
self.bot_token = "[TOOL_CALLS] ["
|
self.bot_token = "[TOOL_CALLS] ["
|
||||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||||
|
|
||||||
|
def has_tool_call(self, text: str) -> bool:
|
||||||
|
"""Check if the text contains a Mistral format tool call."""
|
||||||
|
return self.bot_token in text
|
||||||
|
|
||||||
def _clean_text(self, text: str) -> str:
|
def _clean_text(self, text: str) -> str:
|
||||||
"""
|
"""
|
||||||
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
|
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
|
||||||
@@ -397,12 +405,21 @@ class Llama32Detector(BaseFormatDetector):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.bot_token = "<|python_tag|>"
|
self.bot_token = "<|python_tag|>"
|
||||||
|
|
||||||
|
def has_tool_call(self, text: str) -> bool:
|
||||||
|
"""Check if the text contains a Llama 3.2 format tool call."""
|
||||||
|
# depending on the prompt format the Llama model may or may not
|
||||||
|
# prefix the output with the <|python_tag|> token
|
||||||
|
return "<|python_tag|>" in text or text.startswith("{")
|
||||||
|
|
||||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||||
"""Parse function calls from text, handling multiple JSON objects."""
|
"""Parse function calls from text, handling multiple JSON objects."""
|
||||||
if "<|python_tag|>" not in text:
|
if "<|python_tag|>" not in text and not text.startswith("{"):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
_, action_text = text.split("<|python_tag|>")
|
if "<|python_tag|>" in text:
|
||||||
|
_, action_text = text.split("<|python_tag|>")
|
||||||
|
else:
|
||||||
|
action_text = text
|
||||||
|
|
||||||
# Split by semicolon and process each part
|
# Split by semicolon and process each part
|
||||||
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
|
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
|
||||||
@@ -501,6 +518,20 @@ class FunctionCallParser:
|
|||||||
self.multi_format_parser = MultiFormatParser(detectors)
|
self.multi_format_parser = MultiFormatParser(detectors)
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
|
||||||
|
def has_tool_call(self, text: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the given text contains a tool call in the format supported by this parser.
|
||||||
|
This delegates to the detector's implementation.
|
||||||
|
|
||||||
|
:param text: The text to check for tool calls
|
||||||
|
:return: True if the text contains a tool call, False otherwise
|
||||||
|
"""
|
||||||
|
# Check all detectors in the multi_format_parser
|
||||||
|
for detector in self.multi_format_parser.detectors:
|
||||||
|
if detector.has_tool_call(text):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def parse_non_stream(self, full_text: str):
|
def parse_non_stream(self, full_text: str):
|
||||||
"""
|
"""
|
||||||
Non-streaming call: one-time parsing
|
Non-streaming call: one-time parsing
|
||||||
|
|||||||
@@ -1115,27 +1115,29 @@ def v1_chat_generate_response(
|
|||||||
else:
|
else:
|
||||||
reasoning_text = None
|
reasoning_text = None
|
||||||
|
|
||||||
if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
|
if tool_choice != "none" and tools:
|
||||||
if finish_reason == "stop":
|
parser = FunctionCallParser(tools, tool_call_parser)
|
||||||
finish_reason = "tool_calls"
|
if parser.has_tool_call(text):
|
||||||
try:
|
if finish_reason["type"] == "stop":
|
||||||
parser = FunctionCallParser(tools, tool_call_parser)
|
finish_reason["type"] = "tool_calls"
|
||||||
full_normal_text, call_info_list = parser.parse_non_stream(text)
|
finish_reason["matched"] = None
|
||||||
tool_calls = [
|
try:
|
||||||
ToolCall(
|
full_normal_text, call_info_list = parser.parse_non_stream(text)
|
||||||
id=str(call_info.tool_index),
|
tool_calls = [
|
||||||
function=FunctionResponse(
|
ToolCall(
|
||||||
name=call_info.name, arguments=call_info.parameters
|
id=str(call_info.tool_index),
|
||||||
),
|
function=FunctionResponse(
|
||||||
|
name=call_info.name, arguments=call_info.parameters
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for call_info in call_info_list
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception: {e}")
|
||||||
|
return create_error_response(
|
||||||
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
"Failed to parse fc related info to json format!",
|
||||||
)
|
)
|
||||||
for call_info in call_info_list
|
|
||||||
]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Exception: {e}")
|
|
||||||
return create_error_response(
|
|
||||||
HTTPStatus.BAD_REQUEST,
|
|
||||||
"Failed to parse fc related info to json format!",
|
|
||||||
)
|
|
||||||
|
|
||||||
if to_file:
|
if to_file:
|
||||||
# to make the choice data json serializable
|
# to make the choice data json serializable
|
||||||
|
|||||||
Reference in New Issue
Block a user