feat: Improve Mistral and Qwen25 function call parsing (#6597)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MistralDetector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Mistral models.
|
||||
Assumes function call format:
|
||||
[TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
|
||||
[TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}]
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -32,21 +35,6 @@ class MistralDetector(BaseFormatDetector):
|
||||
"""Check if the text contains a Mistral format tool call."""
|
||||
return self.bot_token in text
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""
|
||||
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
|
||||
for example,
|
||||
text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
|
||||
return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
|
||||
The key pattern is [TOOL_CALLS] [...]
|
||||
"""
|
||||
# TODO: check if Mistral supports multiple tool calls, currently assume only support one tool call
|
||||
find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
|
||||
if len(find_results) > 0:
|
||||
return find_results[0]
|
||||
else:
|
||||
return ""
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
@@ -57,17 +45,74 @@ class MistralDetector(BaseFormatDetector):
|
||||
"""
|
||||
idx = text.find(self.bot_token)
|
||||
normal_text = text[:idx].strip() if idx != -1 else text
|
||||
text = self._clean_text(text)
|
||||
tool_content = text.replace("[TOOL_CALLS]", "").strip()
|
||||
raw_tool_calls = self.tool_call_regex.findall(tool_content)
|
||||
|
||||
if self.bot_token not in text:
|
||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||
|
||||
# Extract the JSON array part from [TOOL_CALLS] [...]
|
||||
# Use bracket counting to properly handle nested brackets in JSON content
|
||||
json_array_str = self._extract_json_array(text)
|
||||
if not json_array_str:
|
||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||
|
||||
calls = []
|
||||
if len(raw_tool_calls) > 0:
|
||||
raw_tool_call = raw_tool_calls[0]
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
for match_result in function_call_arr:
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
try:
|
||||
function_call_arr = json.loads(json_array_str)
|
||||
# Handle both single object and array of objects
|
||||
if not isinstance(function_call_arr, list):
|
||||
function_call_arr = [function_call_arr]
|
||||
calls = self.parse_base_json(function_call_arr, tools)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"Failed to parse JSON part: {json_array_str}, JSON parse error: {str(e)}"
|
||||
)
|
||||
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
def _extract_json_array(self, text: str) -> str:
|
||||
"""
|
||||
Extract the JSON array part using bracket counting to handle nested brackets.
|
||||
|
||||
:param text: The complete text containing [TOOL_CALLS] [...]
|
||||
:return: The JSON array string or None if not found
|
||||
"""
|
||||
start_idx = text.find(self.bot_token)
|
||||
if start_idx == -1:
|
||||
return None
|
||||
|
||||
# Start from the opening bracket after [TOOL_CALLS]
|
||||
json_start = (
|
||||
start_idx + len(self.bot_token) - 1
|
||||
) # -1 to include the opening bracket
|
||||
bracket_count = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
for i in range(json_start, len(text)):
|
||||
char = text[i]
|
||||
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
|
||||
if char == "\\":
|
||||
escape_next = True
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
in_string = not in_string
|
||||
continue
|
||||
|
||||
if not in_string:
|
||||
if char == "[":
|
||||
bracket_count += 1
|
||||
elif char == "]":
|
||||
bracket_count -= 1
|
||||
if bracket_count == 0:
|
||||
return text[json_start : i + 1]
|
||||
|
||||
return None
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
return lambda name: StructureInfo(
|
||||
begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
|
||||
@@ -78,7 +123,8 @@ class MistralDetector(BaseFormatDetector):
|
||||
def build_ebnf(self, tools: List[Tool]):
|
||||
return EBNFComposer.build_ebnf(
|
||||
tools,
|
||||
bot_token=self.bot_token,
|
||||
eot_token=self.eot_token,
|
||||
sequence_start_token=self.bot_token,
|
||||
sequence_end_token=self.eot_token,
|
||||
function_format="json",
|
||||
tool_call_separator=", ",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user