feat: Improve Mistral and Qwen25 function call parsing (#6597)

This commit is contained in:
Chang Su
2025-05-25 23:07:23 -07:00
committed by GitHub
parent 65f091310c
commit 16f69b1f65
7 changed files with 318 additions and 61 deletions

View File

@@ -72,20 +72,51 @@ class BaseFormatDetector(ABC):
action = json.loads(text)
return StreamingParseResult(calls=self.parse_base_json(action, tools))
def ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
    """
    Check if buffer ends with a partial bot_token and return its length.

    For some formats, the bot_token is not a single token in the model's
    vocabulary (e.g. `[TOOL_CALLS] [` in Mistral), so it can arrive split
    across several streaming chunks; callers use a non-zero result to keep
    buffering instead of flushing normal text.

    Args:
        buffer: Accumulated streamed text.
        bot_token: The tool-call begin marker to look for.

    Returns:
        Length of the longest proper suffix of ``buffer`` that is a proper
        prefix of ``bot_token``; 0 if there is none. A complete bot_token
        at the end is intentionally NOT reported here (it is found by the
        caller's substring check instead).
    """
    # Scan from the longest candidate downward so overlapping prefixes
    # (e.g. bot_token="aab" with buffer ending in "aa") report the full
    # partial length rather than the shortest match; under-reporting would
    # let callers emit part of a token as normal text.
    longest = min(len(buffer), len(bot_token) - 1)
    for length in range(longest, 0, -1):
        if bot_token.startswith(buffer[-length:]):
            return length
    return 0
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing with tool validation.
This base implementation works best with formats where:
1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array)
2. JSON can be parsed incrementally using partial_json_loads
3. Multiple tool calls are separated by "; " or ", "
Examples of incompatible formats (need custom implementation, may reuse some logic from this class):
- Each tool call is wrapped in a separate block: See Qwen25Detector
- Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...]
- Tool call is Pythonic style
For incompatible formats, detectors should override this method with custom logic.
"""
# Append new text to buffer
self._buffer += new_text
current_text = self._buffer
if not (self.bot_token in current_text or current_text.startswith("{")):
self._buffer = ""
if self.eot_token in new_text:
new_text = new_text.replace(self.eot_token, "")
return StreamingParseResult(normal_text=new_text)
# Only clear buffer if we're sure no tool call is starting
if not self.ends_with_partial_token(self._buffer, self.bot_token):
normal_text = self._buffer
self._buffer = ""
if self.eot_token in normal_text:
normal_text = normal_text.replace(self.eot_token, "")
return StreamingParseResult(normal_text=normal_text)
else:
# Might be partial bot_token, keep buffering
return StreamingParseResult()
# Build tool indices if not already built
if not hasattr(self, "_tool_indices"):

View File

@@ -149,8 +149,8 @@ class DeepSeekV3Detector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]):
    """
    Build an EBNF grammar constraining output to DeepSeek-V3's tool-call
    format, where the whole sequence is wrapped by bot_token/eot_token and
    each call uses the `<tool call begin>function<tool sep>{name}` layout.
    """
    # Use the sequence-level wrapper parameters; the legacy
    # bot_token/eot_token keyword arguments were superseded by
    # sequence_start_token/sequence_end_token and must not be passed twice.
    return EBNFComposer.build_ebnf(
        tools,
        sequence_start_token=self.bot_token,
        sequence_end_token=self.eot_token,
        tool_call_separator="",
        call_rule_fmt='"<tool▁call▁begin>function<tool▁sep>{name}\\n```json\\n" {arguments_rule} "\\n```<tool▁call▁end>"',
        function_format="json",
    )

View File

@@ -30,11 +30,6 @@ class EBNFComposer:
ws ::= [ \n\t]*
"""
# Maps function_format -> EBNF fragment describing how the overall tool-call
# sequence is laid out: pythonic calls are emitted as a bracketed,
# comma-separated list, while JSON formats emit bare function_call units.
TOOL_CALLS_MAP = {
    "pythonic": '"[" function_call ("," function_call)* "]"',
    "json": "function_call",
}
CALL_RULE_MAP = {
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
@@ -138,35 +133,54 @@ class EBNFComposer:
@staticmethod
def build_ebnf(
tools,
*,
call_rule_fmt: Optional[str] = None,
function_format: Literal["pythonic", "json"] = "json",
bot_token: Optional[str] = None,
eot_token: Optional[str] = None,
# Parameters for wrapping the entire sequence of tool calls
sequence_start_token: Optional[str] = None,
sequence_end_token: Optional[str] = None,
# Parameters for wrapping individual tool calls
individual_call_start_token: Optional[str] = None,
individual_call_end_token: Optional[str] = None,
# Parameter for separating multiple tool calls
tool_call_separator: Optional[str] = None,
call_rule_fmt: Optional[str] = None,
):
"""
Generalized EBNF builder for all detectors.
Args:
tools: List of Tool objects to generate EBNF grammar for
function_format: The format of function calls, either "pythonic" or "json"
sequence_start_token: Token that wraps the entire sequence of tool calls (start)
sequence_end_token: Token that wraps the entire sequence of tool calls (end)
individual_call_start_token: Token that wraps each individual tool call (start)
individual_call_end_token: Token that wraps each individual tool call (end)
tool_call_separator: The separator between multiple tool calls
call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
format based on function_format will be used.
function_format: The format of function calls, either "pythonic" or "json"
bot_token: The token that indicates the start of a tool call section
eot_token: The token that indicates the end of a tool call section
tool_call_separator: The separator between multiple tool calls
"""
# =================================================================
# Step 1: Determine the root tool calls rule
# =================================================================
if bot_token and eot_token:
if tool_call_separator:
root_rule = f'"{bot_token}" function_call ( "{tool_call_separator}" function_call )* "{eot_token}"'
else:
root_rule = f'"{bot_token}" function_call "{eot_token}"'
# Handle a single function call
if individual_call_start_token and individual_call_end_token:
function_call_unit = f'"{individual_call_start_token}" function_call "{individual_call_end_token}"'
else:
root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format]
function_call_unit = "function_call"
# Handle multiple function calls with separators
if tool_call_separator is not None:
base_pattern = f'{function_call_unit} ( "{tool_call_separator}" {function_call_unit} )*'
else:
# Assume only support single function call
base_pattern = function_call_unit
# Apply sequence-level wrapping if needed
if sequence_start_token and sequence_end_token:
root_rule = (
f'"{sequence_start_token}" {base_pattern} "{sequence_end_token}"'
)
else:
root_rule = base_pattern
# =================================================================
# Step 2: Build the header rules

View File

@@ -1,4 +1,5 @@
import json
import logging
import re
from typing import List
@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class MistralDetector(BaseFormatDetector):
"""
Detector for Mistral models.
Assumes function call format:
[TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
[TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}]
"""
def __init__(self):
@@ -32,21 +35,6 @@ class MistralDetector(BaseFormatDetector):
"""Check if the text contains a Mistral format tool call."""
return self.bot_token in text
def _clean_text(self, text: str) -> str:
"""
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
for example,
text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
The key pattern is [TOOL_CALLS] [...]
"""
# TODO: check if Mistral supports multiple tool calls, currently assume only support one tool call
find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
if len(find_results) > 0:
return find_results[0]
else:
return ""
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
    """
    One-time parsing: Detects and parses tool calls in the provided text.

    :param text: The complete text to parse.
    :param tools: List of available tools.
    :return: StreamingParseResult with the prose before the tool-call
        section as normal_text and the successfully parsed calls.
    """
    idx = text.find(self.bot_token)
    # Everything before the [TOOL_CALLS] marker is plain assistant text.
    normal_text = text[:idx].strip() if idx != -1 else text
    if self.bot_token not in text:
        return StreamingParseResult(normal_text=normal_text, calls=[])

    # Extract the JSON array part from [TOOL_CALLS] [...].
    # Bracket counting properly handles nested brackets in JSON content.
    json_array_str = self._extract_json_array(text)
    if not json_array_str:
        return StreamingParseResult(normal_text=normal_text, calls=[])

    calls = []
    try:
        function_call_arr = json.loads(json_array_str)
        # Handle both a single object and an array of objects.
        if not isinstance(function_call_arr, list):
            function_call_arr = [function_call_arr]
        calls = self.parse_base_json(function_call_arr, tools)
    except json.JSONDecodeError as e:
        # Malformed JSON: log and fall through with no calls rather than
        # crashing the whole response.
        logger.warning(
            f"Failed to parse JSON part: {json_array_str}, JSON parse error: {str(e)}"
        )
    return StreamingParseResult(normal_text=normal_text, calls=calls)
def _extract_json_array(self, text: str) -> str:
"""
Extract the JSON array part using bracket counting to handle nested brackets.
:param text: The complete text containing [TOOL_CALLS] [...]
:return: The JSON array string or None if not found
"""
start_idx = text.find(self.bot_token)
if start_idx == -1:
return None
# Start from the opening bracket after [TOOL_CALLS]
json_start = (
start_idx + len(self.bot_token) - 1
) # -1 to include the opening bracket
bracket_count = 0
in_string = False
escape_next = False
for i in range(json_start, len(text)):
char = text[i]
if escape_next:
escape_next = False
continue
if char == "\\":
escape_next = True
continue
if char == '"' and not escape_next:
in_string = not in_string
continue
if not in_string:
if char == "[":
bracket_count += 1
elif char == "]":
bracket_count -= 1
if bracket_count == 0:
return text[json_start : i + 1]
return None
def structure_info(self) -> _GetInfoFunc:
return lambda name: StructureInfo(
begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
@@ -78,7 +123,8 @@ class MistralDetector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]):
    """
    Build an EBNF grammar constraining output to Mistral's
    '[TOOL_CALLS] [...]' format: one sequence wrapper around a
    comma-separated JSON array of calls.
    """
    # The sequence-level wrapper parameters replace the legacy
    # bot_token/eot_token keyword arguments; passing both would be a
    # duplicate/unknown-argument error.
    return EBNFComposer.build_ebnf(
        tools,
        sequence_start_token=self.bot_token,
        sequence_end_token=self.eot_token,
        function_format="json",
        tool_call_separator=", ",
    )

View File

@@ -156,8 +156,8 @@ class PythonicDetector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
    """
    Build an EBNF grammar for pythonic tool calls: a single bracketed,
    comma-separated list like '[f(...), g(...)]'.
    """
    # sequence_start_token/sequence_end_token supersede the legacy
    # bot_token/eot_token keyword arguments.
    return EBNFComposer.build_ebnf(
        tools,
        sequence_start_token="[",
        sequence_end_token="]",
        tool_call_separator=",",
        function_format="pythonic",
    )

View File

@@ -1,4 +1,5 @@
import json
import logging
import re
from typing import List
@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class Qwen25Detector(BaseFormatDetector):
"""
Detector for Qwen 2.5 models.
Assumes function call format:
<tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
<tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
"""
def __init__(self):
    """
    Initializes the detector with necessary state variables.
    """
    super().__init__()
    # Qwen 2.5 emits each call as <tool_call>\n{...}\n</tool_call>; the
    # newlines are part of the markers so streamed chunks line up exactly.
    self.bot_token = "<tool_call>\n"
    self.eot_token = "\n</tool_call>"
    self._normal_text_buffer = ""  # Buffer for handling partial end tokens
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Qwen 2.5 format tool call."""
@@ -43,15 +47,64 @@ class Qwen25Detector(BaseFormatDetector):
normal_text = text[:idx].strip() if idx != -1 else text
if self.bot_token not in text:
return StreamingParseResult(normal_text=normal_text, calls=[])
pattern = rf"{self.bot_token}(.*?){self.eot_token}"
# Find all <tool_call>\n...\n</tool_call> blocks
pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
match_result_list = re.findall(pattern, text, re.DOTALL)
calls = []
for match_result in match_result_list:
match_result = json.loads(match_result)
calls.extend(self.parse_base_json(match_result, tools))
try:
parsed_call = json.loads(match_result.strip())
calls.extend(self.parse_base_json(parsed_call, tools))
except json.JSONDecodeError as e:
logger.warning(
f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}"
)
continue
return StreamingParseResult(normal_text=normal_text, calls=calls)
def parse_streaming_increment(
    self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
    """
    Streaming incremental parsing for Qwen 2.5 tool calls.

    Uses the base class implementation for the actual tool-call parsing,
    then post-processes the returned normal text: because the end token
    can be streamed character by character, any suffix of the buffered
    normal text that is a partial "</tool_call>" is held back in
    ``self._normal_text_buffer`` until it either completes (and is
    stripped from the output) or turns out to be ordinary text (and is
    released).

    Args:
        new_text: The next streamed chunk of model output.
        tools: List of available tools, forwarded to the base parser.

    Returns:
        StreamingParseResult whose normal_text has complete
        "</tool_call>" markers removed and partial ones withheld.
    """
    result = super().parse_streaming_increment(new_text, tools)

    # Handle partial end tokens that are streamed character by character.
    if result.normal_text:
        self._normal_text_buffer += result.normal_text

        # Check if buffer contains complete end token (without the leading
        # newline, since the base parser may have consumed it already).
        end_token_without_newline = self.eot_token[1:]  # "</tool_call>"
        if end_token_without_newline in self._normal_text_buffer:
            # Strip every complete end token and flush the buffer.
            cleaned_text = self._normal_text_buffer.replace(
                end_token_without_newline, ""
            )
            self._normal_text_buffer = ""
            result.normal_text = cleaned_text
        else:
            # Check if buffer might contain a partial end token at the end.
            partial_match_len = self.ends_with_partial_token(
                self._normal_text_buffer, end_token_without_newline
            )

            if partial_match_len:
                # Keep the potential partial match in the buffer and
                # return only the text before it.
                result.normal_text = self._normal_text_buffer[:-partial_match_len]
                self._normal_text_buffer = self._normal_text_buffer[
                    -partial_match_len:
                ]
            else:
                # No partial match, return all buffered text.
                result.normal_text = self._normal_text_buffer
                self._normal_text_buffer = ""

    return result
def structure_info(self) -> _GetInfoFunc:
# TODO: Update the begin and end tokens with '\n' if necessary
return lambda name: StructureInfo(
begin='<tool_call>{"name":"' + name + '", "arguments":',
end="}</tool_call>",
@@ -61,7 +114,8 @@ class Qwen25Detector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]):
    """
    Build an EBNF grammar where each call is wrapped in its own
    <tool_call>...</tool_call> block and blocks are separated by newlines.
    """
    # individual_call_* parameters supersede the legacy bot_token/eot_token
    # keyword arguments; the literal newlines in the tokens are escaped for
    # the grammar string.
    return EBNFComposer.build_ebnf(
        tools,
        individual_call_start_token=self.bot_token.replace("\n", "\\n"),
        individual_call_end_token=self.eot_token.replace("\n", "\\n"),
        tool_call_separator="\\n",
        function_format="json",
    )