feat: Improve Mistral and Qwen25 function call parsing (#6597)

This commit is contained in:
Chang Su
2025-05-25 23:07:23 -07:00
committed by GitHub
parent 65f091310c
commit 16f69b1f65
7 changed files with 318 additions and 61 deletions

View File

@@ -1,4 +1,5 @@
import json
import logging
import re
from typing import List
@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class MistralDetector(BaseFormatDetector):
"""
Detector for Mistral models.
Assumes function call format:
[TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
[TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}]
"""
def __init__(self):
@@ -32,21 +35,6 @@ class MistralDetector(BaseFormatDetector):
"""Check if the text contains a Mistral format tool call."""
return self.bot_token in text
def _clean_text(self, text: str) -> str:
"""
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
for example,
text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
The key pattern is [TOOL_CALLS] [...]
"""
# TODO: check if Mistral supports multiple tool calls, currently assume only support one tool call
find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
if len(find_results) > 0:
return find_results[0]
else:
return ""
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
@@ -57,17 +45,74 @@ class MistralDetector(BaseFormatDetector):
"""
idx = text.find(self.bot_token)
normal_text = text[:idx].strip() if idx != -1 else text
text = self._clean_text(text)
tool_content = text.replace("[TOOL_CALLS]", "").strip()
raw_tool_calls = self.tool_call_regex.findall(tool_content)
if self.bot_token not in text:
return StreamingParseResult(normal_text=normal_text, calls=[])
# Extract the JSON array part from [TOOL_CALLS] [...]
# Use bracket counting to properly handle nested brackets in JSON content
json_array_str = self._extract_json_array(text)
if not json_array_str:
return StreamingParseResult(normal_text=normal_text, calls=[])
calls = []
if len(raw_tool_calls) > 0:
raw_tool_call = raw_tool_calls[0]
function_call_arr = json.loads(raw_tool_call)
for match_result in function_call_arr:
calls.extend(self.parse_base_json(match_result, tools))
try:
function_call_arr = json.loads(json_array_str)
# Handle both single object and array of objects
if not isinstance(function_call_arr, list):
function_call_arr = [function_call_arr]
calls = self.parse_base_json(function_call_arr, tools)
except json.JSONDecodeError as e:
logger.warning(
f"Failed to parse JSON part: {json_array_str}, JSON parse error: {str(e)}"
)
return StreamingParseResult(normal_text=normal_text, calls=calls)
def _extract_json_array(self, text: str) -> str:
"""
Extract the JSON array part using bracket counting to handle nested brackets.
:param text: The complete text containing [TOOL_CALLS] [...]
:return: The JSON array string or None if not found
"""
start_idx = text.find(self.bot_token)
if start_idx == -1:
return None
# Start from the opening bracket after [TOOL_CALLS]
json_start = (
start_idx + len(self.bot_token) - 1
) # -1 to include the opening bracket
bracket_count = 0
in_string = False
escape_next = False
for i in range(json_start, len(text)):
char = text[i]
if escape_next:
escape_next = False
continue
if char == "\\":
escape_next = True
continue
if char == '"' and not escape_next:
in_string = not in_string
continue
if not in_string:
if char == "[":
bracket_count += 1
elif char == "]":
bracket_count -= 1
if bracket_count == 0:
return text[json_start : i + 1]
return None
def structure_info(self) -> _GetInfoFunc:
return lambda name: StructureInfo(
begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
@@ -78,7 +123,8 @@ class MistralDetector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf(
tools,
bot_token=self.bot_token,
eot_token=self.eot_token,
sequence_start_token=self.bot_token,
sequence_end_token=self.eot_token,
function_format="json",
tool_call_separator=", ",
)