feat: Improve Mistral and Qwen25 function call parsing (#6597)
This commit is contained in:
@@ -72,20 +72,51 @@ class BaseFormatDetector(ABC):
|
||||
action = json.loads(text)
|
||||
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
||||
|
||||
def ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
    """
    Check whether ``buffer`` ends with a proper prefix of ``bot_token``.

    For some formats the bot_token is not a single token in the model's
    vocabulary (for example ``[TOOL_CALLS] [`` in Mistral), so it can arrive
    split across several streamed chunks.

    Args:
        buffer: Text accumulated so far.
        bot_token: Marker whose partial occurrence is being looked for.

    Returns:
        The length of the longest suffix of ``buffer`` that is a strict
        prefix of ``bot_token`` (the complete token never counts as
        partial), or 0 when there is no such suffix.
    """
    # A partial match is at most len(bot_token) - 1 characters long and can
    # never be longer than the buffer itself.
    longest_candidate = min(len(buffer), len(bot_token) - 1)

    # Scan from the longest candidate down so that tokens whose prefix
    # overlaps itself (e.g. "aab" against a buffer ending in "aa") report the
    # full partial match instead of only its last character.
    for size in range(longest_candidate, 0, -1):
        if bot_token.startswith(buffer[-size:]):
            return size
    return 0
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing with tool validation.
|
||||
|
||||
This base implementation works best with formats where:
|
||||
1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array)
|
||||
2. JSON can be parsed incrementally using partial_json_loads
|
||||
3. Multiple tool calls are separated by "; " or ", "
|
||||
|
||||
Examples of incompatible formats (need custom implementation, may reuse some logic from this class):
|
||||
- Each tool call is wrapped in a separate block: See Qwen25Detector
|
||||
- Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...]
|
||||
- Tool call is Pythonic style
|
||||
|
||||
For incompatible formats, detectors should override this method with custom logic.
|
||||
"""
|
||||
# Append new text to buffer
|
||||
self._buffer += new_text
|
||||
current_text = self._buffer
|
||||
if not (self.bot_token in current_text or current_text.startswith("{")):
|
||||
self._buffer = ""
|
||||
if self.eot_token in new_text:
|
||||
new_text = new_text.replace(self.eot_token, "")
|
||||
return StreamingParseResult(normal_text=new_text)
|
||||
# Only clear buffer if we're sure no tool call is starting
|
||||
if not self.ends_with_partial_token(self._buffer, self.bot_token):
|
||||
normal_text = self._buffer
|
||||
self._buffer = ""
|
||||
if self.eot_token in normal_text:
|
||||
normal_text = normal_text.replace(self.eot_token, "")
|
||||
return StreamingParseResult(normal_text=normal_text)
|
||||
else:
|
||||
# Might be partial bot_token, keep buffering
|
||||
return StreamingParseResult()
|
||||
|
||||
# Build tool indices if not already built
|
||||
if not hasattr(self, "_tool_indices"):
|
||||
|
||||
@@ -149,8 +149,8 @@ class DeepSeekV3Detector(BaseFormatDetector):
|
||||
def build_ebnf(self, tools: List[Tool]):
|
||||
return EBNFComposer.build_ebnf(
|
||||
tools,
|
||||
bot_token=self.bot_token,
|
||||
eot_token=self.eot_token,
|
||||
sequence_start_token=self.bot_token,
|
||||
sequence_end_token=self.eot_token,
|
||||
tool_call_separator="",
|
||||
call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"',
|
||||
function_format="json",
|
||||
|
||||
@@ -30,11 +30,6 @@ class EBNFComposer:
|
||||
ws ::= [ \n\t]*
|
||||
"""
|
||||
|
||||
TOOL_CALLS_MAP = {
|
||||
"pythonic": '"[" function_call ("," function_call)* "]"',
|
||||
"json": "function_call",
|
||||
}
|
||||
|
||||
CALL_RULE_MAP = {
|
||||
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
|
||||
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
|
||||
@@ -138,35 +133,54 @@ class EBNFComposer:
|
||||
@staticmethod
|
||||
def build_ebnf(
|
||||
tools,
|
||||
*,
|
||||
call_rule_fmt: Optional[str] = None,
|
||||
function_format: Literal["pythonic", "json"] = "json",
|
||||
bot_token: Optional[str] = None,
|
||||
eot_token: Optional[str] = None,
|
||||
# Parameters for wrapping the entire sequence of tool calls
|
||||
sequence_start_token: Optional[str] = None,
|
||||
sequence_end_token: Optional[str] = None,
|
||||
# Parameters for wrapping individual tool calls
|
||||
individual_call_start_token: Optional[str] = None,
|
||||
individual_call_end_token: Optional[str] = None,
|
||||
# Parameter for separating multiple tool calls
|
||||
tool_call_separator: Optional[str] = None,
|
||||
call_rule_fmt: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Generalized EBNF builder for all detectors.
|
||||
Args:
|
||||
tools: List of Tool objects to generate EBNF grammar for
|
||||
function_format: The format of function calls, either "pythonic" or "json"
|
||||
sequence_start_token: Token that wraps the entire sequence of tool calls (start)
|
||||
sequence_end_token: Token that wraps the entire sequence of tool calls (end)
|
||||
individual_call_start_token: Token that wraps each individual tool call (start)
|
||||
individual_call_end_token: Token that wraps each individual tool call (end)
|
||||
tool_call_separator: The separator between multiple tool calls
|
||||
call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
|
||||
the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
|
||||
format based on function_format will be used.
|
||||
function_format: The format of function calls, either "pythonic" or "json"
|
||||
bot_token: The token that indicates the start of a tool call section
|
||||
eot_token: The token that indicates the end of a tool call section
|
||||
tool_call_separator: The separator between multiple tool calls
|
||||
"""
|
||||
# =================================================================
|
||||
# Step 1: Determine the root tool calls rule
|
||||
# =================================================================
|
||||
if bot_token and eot_token:
|
||||
if tool_call_separator:
|
||||
root_rule = f'"{bot_token}" function_call ( "{tool_call_separator}" function_call )* "{eot_token}"'
|
||||
else:
|
||||
root_rule = f'"{bot_token}" function_call "{eot_token}"'
|
||||
# Handle a single function call
|
||||
if individual_call_start_token and individual_call_end_token:
|
||||
function_call_unit = f'"{individual_call_start_token}" function_call "{individual_call_end_token}"'
|
||||
else:
|
||||
root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format]
|
||||
function_call_unit = "function_call"
|
||||
|
||||
# Handle multiple function calls with separators
|
||||
if tool_call_separator is not None:
|
||||
base_pattern = f'{function_call_unit} ( "{tool_call_separator}" {function_call_unit} )*'
|
||||
else:
|
||||
# Assume only support single function call
|
||||
base_pattern = function_call_unit
|
||||
|
||||
# Apply sequence-level wrapping if needed
|
||||
if sequence_start_token and sequence_end_token:
|
||||
root_rule = (
|
||||
f'"{sequence_start_token}" {base_pattern} "{sequence_end_token}"'
|
||||
)
|
||||
else:
|
||||
root_rule = base_pattern
|
||||
|
||||
# =================================================================
|
||||
# Step 2: Build the header rules
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MistralDetector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Mistral models.
|
||||
Assumes function call format:
|
||||
[TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
|
||||
[TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}]
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -32,21 +35,6 @@ class MistralDetector(BaseFormatDetector):
|
||||
"""Check if the text contains a Mistral format tool call."""
|
||||
return self.bot_token in text
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""
|
||||
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
|
||||
for example,
|
||||
text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
|
||||
return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
|
||||
The key pattern is [TOOL_CALLS] [...]
|
||||
"""
|
||||
# TODO: check if Mistral supports multiple tool calls, currently assume only support one tool call
|
||||
find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
|
||||
if len(find_results) > 0:
|
||||
return find_results[0]
|
||||
else:
|
||||
return ""
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
@@ -57,17 +45,74 @@ class MistralDetector(BaseFormatDetector):
|
||||
"""
|
||||
idx = text.find(self.bot_token)
|
||||
normal_text = text[:idx].strip() if idx != -1 else text
|
||||
text = self._clean_text(text)
|
||||
tool_content = text.replace("[TOOL_CALLS]", "").strip()
|
||||
raw_tool_calls = self.tool_call_regex.findall(tool_content)
|
||||
|
||||
if self.bot_token not in text:
|
||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||
|
||||
# Extract the JSON array part from [TOOL_CALLS] [...]
|
||||
# Use bracket counting to properly handle nested brackets in JSON content
|
||||
json_array_str = self._extract_json_array(text)
|
||||
if not json_array_str:
|
||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||
|
||||
calls = []
|
||||
if len(raw_tool_calls) > 0:
|
||||
raw_tool_call = raw_tool_calls[0]
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
for match_result in function_call_arr:
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
try:
|
||||
function_call_arr = json.loads(json_array_str)
|
||||
# Handle both single object and array of objects
|
||||
if not isinstance(function_call_arr, list):
|
||||
function_call_arr = [function_call_arr]
|
||||
calls = self.parse_base_json(function_call_arr, tools)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"Failed to parse JSON part: {json_array_str}, JSON parse error: {str(e)}"
|
||||
)
|
||||
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
def _extract_json_array(self, text: str) -> str:
|
||||
"""
|
||||
Extract the JSON array part using bracket counting to handle nested brackets.
|
||||
|
||||
:param text: The complete text containing [TOOL_CALLS] [...]
|
||||
:return: The JSON array string or None if not found
|
||||
"""
|
||||
start_idx = text.find(self.bot_token)
|
||||
if start_idx == -1:
|
||||
return None
|
||||
|
||||
# Start from the opening bracket after [TOOL_CALLS]
|
||||
json_start = (
|
||||
start_idx + len(self.bot_token) - 1
|
||||
) # -1 to include the opening bracket
|
||||
bracket_count = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
for i in range(json_start, len(text)):
|
||||
char = text[i]
|
||||
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
|
||||
if char == "\\":
|
||||
escape_next = True
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
in_string = not in_string
|
||||
continue
|
||||
|
||||
if not in_string:
|
||||
if char == "[":
|
||||
bracket_count += 1
|
||||
elif char == "]":
|
||||
bracket_count -= 1
|
||||
if bracket_count == 0:
|
||||
return text[json_start : i + 1]
|
||||
|
||||
return None
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
return lambda name: StructureInfo(
|
||||
begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
|
||||
@@ -78,7 +123,8 @@ class MistralDetector(BaseFormatDetector):
|
||||
def build_ebnf(self, tools: List[Tool]):
    """
    Build the EBNF grammar for Mistral tool calls.

    Delegates to EBNFComposer.build_ebnf, wrapping the whole sequence of
    JSON-format calls between this detector's bot_token and eot_token and
    separating consecutive calls with ", ".

    Args:
        tools: Tools the grammar should be generated for.

    Returns:
        Whatever EBNFComposer.build_ebnf returns for these settings.
    """
    # Only the sequence-level keywords are passed: the legacy
    # bot_token/eot_token keywords carried the same values and are superseded.
    return EBNFComposer.build_ebnf(
        tools,
        sequence_start_token=self.bot_token,
        sequence_end_token=self.eot_token,
        function_format="json",
        tool_call_separator=", ",
    )
|
||||
|
||||
@@ -156,8 +156,8 @@ class PythonicDetector(BaseFormatDetector):
|
||||
def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
    """
    Build the EBNF grammar for Pythonic-style tool calls.

    Delegates to EBNFComposer.build_ebnf, wrapping the whole sequence of
    calls in "[" ... "]" and separating consecutive calls with ",".

    Args:
        tools: Tools the grammar should be generated for.

    Returns:
        Whatever EBNFComposer.build_ebnf returns for these settings.
    """
    # Only the sequence-level keywords are passed: the legacy
    # bot_token/eot_token keywords carried the same values and are superseded.
    return EBNFComposer.build_ebnf(
        tools,
        sequence_start_token="[",
        sequence_end_token="]",
        tool_call_separator=",",
        function_format="pythonic",
    )
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
from sglang.srt.openai_api.protocol import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Qwen25Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Qwen 2.5 models.
|
||||
Assumes function call format:
|
||||
<tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
|
||||
<tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -24,8 +27,9 @@ class Qwen25Detector(BaseFormatDetector):
|
||||
Initializes the detector with necessary state variables.
|
||||
"""
|
||||
super().__init__()
|
||||
self.bot_token = "<tool_call>"
|
||||
self.eot_token = "</tool_call>"
|
||||
self.bot_token = "<tool_call>\n"
|
||||
self.eot_token = "\n</tool_call>"
|
||||
self._normal_text_buffer = "" # Buffer for handling partial end tokens
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""Check if the text contains a Qwen 2.5 format tool call."""
|
||||
@@ -43,15 +47,64 @@ class Qwen25Detector(BaseFormatDetector):
|
||||
normal_text = text[:idx].strip() if idx != -1 else text
|
||||
if self.bot_token not in text:
|
||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||
pattern = rf"{self.bot_token}(.*?){self.eot_token}"
|
||||
|
||||
# Find all <tool_call>\n...\n</tool_call> blocks
|
||||
pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
|
||||
match_result_list = re.findall(pattern, text, re.DOTALL)
|
||||
calls = []
|
||||
for match_result in match_result_list:
|
||||
match_result = json.loads(match_result)
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
try:
|
||||
parsed_call = json.loads(match_result.strip())
|
||||
calls.extend(self.parse_base_json(parsed_call, tools))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}"
|
||||
)
|
||||
continue
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
def parse_streaming_increment(
    self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
    """
    Streaming incremental parsing for Qwen 2.5 tool calls.

    The base class does the actual parsing; this override post-processes the
    normal text it returns so that an end tag arriving piece by piece is
    withheld and stripped rather than leaked to the caller.
    """
    result = super().parse_streaming_increment(new_text, tools)

    # Nothing to post-process when the base parser produced no normal text.
    if not result.normal_text:
        return result

    self._normal_text_buffer += result.normal_text
    pending = self._normal_text_buffer

    # The end tag is matched without its leading newline: "</tool_call>"
    closing_tag = self.eot_token[1:]

    if closing_tag in pending:
        # A complete end tag is buffered: strip it and flush everything.
        result.normal_text = pending.replace(closing_tag, "")
        self._normal_text_buffer = ""
        return result

    held_back = self.ends_with_partial_token(pending, closing_tag)
    if held_back:
        # The tail may be the start of an end tag: keep it, emit the rest.
        result.normal_text = pending[:-held_back]
        self._normal_text_buffer = pending[-held_back:]
    else:
        # No possible end tag at the tail: emit all buffered text.
        result.normal_text = pending
        self._normal_text_buffer = ""

    return result
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
|
||||
# TODO: Update the begin and end tokens with '\n' if necessary
|
||||
return lambda name: StructureInfo(
|
||||
begin='<tool_call>{"name":"' + name + '", "arguments":',
|
||||
end="}</tool_call>",
|
||||
@@ -61,7 +114,8 @@ class Qwen25Detector(BaseFormatDetector):
|
||||
def build_ebnf(self, tools: List[Tool]):
    """
    Build the EBNF grammar for Qwen 2.5 tool calls.

    Delegates to EBNFComposer.build_ebnf, wrapping each JSON-format call in
    this detector's bot_token / eot_token and separating consecutive calls
    with a newline.

    Args:
        tools: Tools the grammar should be generated for.

    Returns:
        Whatever EBNFComposer.build_ebnf returns for these settings.
    """
    # Newlines are passed in escaped form ("\\n") because the tokens are
    # embedded inside quoted grammar literals. The legacy bot_token/eot_token
    # keywords are omitted: the per-call wrapping tokens supersede them.
    return EBNFComposer.build_ebnf(
        tools,
        individual_call_start_token=self.bot_token.replace("\n", "\\n"),
        individual_call_end_token=self.eot_token.replace("\n", "\\n"),
        tool_call_separator="\\n",
        function_format="json",
    )
|
||||
|
||||
Reference in New Issue
Block a user