feat: Improve Mistral and Qwen25 function call parsing (#6597)
This commit is contained in:
@@ -72,20 +72,51 @@ class BaseFormatDetector(ABC):
|
|||||||
action = json.loads(text)
|
action = json.loads(text)
|
||||||
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
||||||
|
|
||||||
|
def ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
    """
    Check if buffer ends with a partial (incomplete) bot_token.

    For some formats the bot_token is not a single token in the model's
    vocabulary, such as `[TOOL_CALLS] [` in Mistral, so it may be spread
    over several streamed chunks.

    Args:
        buffer: The text accumulated so far.
        bot_token: The marker whose partial occurrence is searched for.

    Returns:
        The length of the longest suffix of ``buffer`` that is a strict
        prefix of ``bot_token``, or 0 when there is none. A complete
        ``bot_token`` at the end of ``buffer`` is not reported here.
    """
    # Only strict prefixes are considered, hence the ``len(bot_token) - 1`` cap.
    longest_candidate = min(len(buffer), len(bot_token) - 1)

    # Scan from the longest candidate down so that tokens with a repeated
    # prefix (e.g. "aab" against a buffer ending in "aa") report the full
    # overlap instead of the shortest one.
    for size in range(longest_candidate, 0, -1):
        if bot_token.startswith(buffer[-size:]):
            return size
    return 0
|
||||||
|
|
||||||
def parse_streaming_increment(
|
def parse_streaming_increment(
|
||||||
self, new_text: str, tools: List[Tool]
|
self, new_text: str, tools: List[Tool]
|
||||||
) -> StreamingParseResult:
|
) -> StreamingParseResult:
|
||||||
"""
|
"""
|
||||||
Streaming incremental parsing with tool validation.
|
Streaming incremental parsing with tool validation.
|
||||||
|
|
||||||
|
This base implementation works best with formats where:
|
||||||
|
1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array)
|
||||||
|
2. JSON can be parsed incrementally using partial_json_loads
|
||||||
|
3. Multiple tool calls are separated by "; " or ", "
|
||||||
|
|
||||||
|
Examples of incompatible formats (need custom implementation, may reuse some logic from this class):
|
||||||
|
- Each tool call is wrapped in a separate block: See Qwen25Detector
|
||||||
|
- Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...]
|
||||||
|
- Tool call is Pythonic style
|
||||||
|
|
||||||
|
For incompatible formats, detectors should override this method with custom logic.
|
||||||
"""
|
"""
|
||||||
# Append new text to buffer
|
# Append new text to buffer
|
||||||
self._buffer += new_text
|
self._buffer += new_text
|
||||||
current_text = self._buffer
|
current_text = self._buffer
|
||||||
if not (self.bot_token in current_text or current_text.startswith("{")):
|
if not (self.bot_token in current_text or current_text.startswith("{")):
|
||||||
self._buffer = ""
|
# Only clear buffer if we're sure no tool call is starting
|
||||||
if self.eot_token in new_text:
|
if not self.ends_with_partial_token(self._buffer, self.bot_token):
|
||||||
new_text = new_text.replace(self.eot_token, "")
|
normal_text = self._buffer
|
||||||
return StreamingParseResult(normal_text=new_text)
|
self._buffer = ""
|
||||||
|
if self.eot_token in normal_text:
|
||||||
|
normal_text = normal_text.replace(self.eot_token, "")
|
||||||
|
return StreamingParseResult(normal_text=normal_text)
|
||||||
|
else:
|
||||||
|
# Might be partial bot_token, keep buffering
|
||||||
|
return StreamingParseResult()
|
||||||
|
|
||||||
# Build tool indices if not already built
|
# Build tool indices if not already built
|
||||||
if not hasattr(self, "_tool_indices"):
|
if not hasattr(self, "_tool_indices"):
|
||||||
|
|||||||
@@ -149,8 +149,8 @@ class DeepSeekV3Detector(BaseFormatDetector):
|
|||||||
def build_ebnf(self, tools: List[Tool]):
|
def build_ebnf(self, tools: List[Tool]):
|
||||||
return EBNFComposer.build_ebnf(
|
return EBNFComposer.build_ebnf(
|
||||||
tools,
|
tools,
|
||||||
bot_token=self.bot_token,
|
sequence_start_token=self.bot_token,
|
||||||
eot_token=self.eot_token,
|
sequence_end_token=self.eot_token,
|
||||||
tool_call_separator="",
|
tool_call_separator="",
|
||||||
call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"',
|
call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"',
|
||||||
function_format="json",
|
function_format="json",
|
||||||
|
|||||||
@@ -30,11 +30,6 @@ class EBNFComposer:
|
|||||||
ws ::= [ \n\t]*
|
ws ::= [ \n\t]*
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TOOL_CALLS_MAP = {
|
|
||||||
"pythonic": '"[" function_call ("," function_call)* "]"',
|
|
||||||
"json": "function_call",
|
|
||||||
}
|
|
||||||
|
|
||||||
CALL_RULE_MAP = {
|
CALL_RULE_MAP = {
|
||||||
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
|
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
|
||||||
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
|
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
|
||||||
@@ -138,35 +133,54 @@ class EBNFComposer:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def build_ebnf(
|
def build_ebnf(
|
||||||
tools,
|
tools,
|
||||||
*,
|
|
||||||
call_rule_fmt: Optional[str] = None,
|
|
||||||
function_format: Literal["pythonic", "json"] = "json",
|
function_format: Literal["pythonic", "json"] = "json",
|
||||||
bot_token: Optional[str] = None,
|
# Parameters for wrapping the entire sequence of tool calls
|
||||||
eot_token: Optional[str] = None,
|
sequence_start_token: Optional[str] = None,
|
||||||
|
sequence_end_token: Optional[str] = None,
|
||||||
|
# Parameters for wrapping individual tool calls
|
||||||
|
individual_call_start_token: Optional[str] = None,
|
||||||
|
individual_call_end_token: Optional[str] = None,
|
||||||
|
# Parameter for separating multiple tool calls
|
||||||
tool_call_separator: Optional[str] = None,
|
tool_call_separator: Optional[str] = None,
|
||||||
|
call_rule_fmt: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generalized EBNF builder for all detectors.
|
Generalized EBNF builder for all detectors.
|
||||||
Args:
|
Args:
|
||||||
tools: List of Tool objects to generate EBNF grammar for
|
tools: List of Tool objects to generate EBNF grammar for
|
||||||
|
function_format: The format of function calls, either "pythonic" or "json"
|
||||||
|
sequence_start_token: Token that wraps the entire sequence of tool calls (start)
|
||||||
|
sequence_end_token: Token that wraps the entire sequence of tool calls (end)
|
||||||
|
individual_call_start_token: Token that wraps each individual tool call (start)
|
||||||
|
individual_call_end_token: Token that wraps each individual tool call (end)
|
||||||
|
tool_call_separator: The separator between multiple tool calls
|
||||||
call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
|
call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
|
||||||
the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
|
the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
|
||||||
format based on function_format will be used.
|
format based on function_format will be used.
|
||||||
function_format: The format of function calls, either "pythonic" or "json"
|
|
||||||
bot_token: The token that indicates the start of a tool call section
|
|
||||||
eot_token: The token that indicates the end of a tool call section
|
|
||||||
tool_call_separator: The separator between multiple tool calls
|
|
||||||
"""
|
"""
|
||||||
# =================================================================
|
# =================================================================
|
||||||
# Step 1: Determine the root tool calls rule
|
# Step 1: Determine the root tool calls rule
|
||||||
# =================================================================
|
# =================================================================
|
||||||
if bot_token and eot_token:
|
# Handle a single function call
|
||||||
if tool_call_separator:
|
if individual_call_start_token and individual_call_end_token:
|
||||||
root_rule = f'"{bot_token}" function_call ( "{tool_call_separator}" function_call )* "{eot_token}"'
|
function_call_unit = f'"{individual_call_start_token}" function_call "{individual_call_end_token}"'
|
||||||
else:
|
|
||||||
root_rule = f'"{bot_token}" function_call "{eot_token}"'
|
|
||||||
else:
|
else:
|
||||||
root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format]
|
function_call_unit = "function_call"
|
||||||
|
|
||||||
|
# Handle multiple function calls with separators
|
||||||
|
if tool_call_separator is not None:
|
||||||
|
base_pattern = f'{function_call_unit} ( "{tool_call_separator}" {function_call_unit} )*'
|
||||||
|
else:
|
||||||
|
# Assume only support single function call
|
||||||
|
base_pattern = function_call_unit
|
||||||
|
|
||||||
|
# Apply sequence-level wrapping if needed
|
||||||
|
if sequence_start_token and sequence_end_token:
|
||||||
|
root_rule = (
|
||||||
|
f'"{sequence_start_token}" {base_pattern} "{sequence_end_token}"'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
root_rule = base_pattern
|
||||||
|
|
||||||
# =================================================================
|
# =================================================================
|
||||||
# Step 2: Build the header rules
|
# Step 2: Build the header rules
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import json
import logging
import re
from typing import List, Optional
|
||||||
|
|
||||||
@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
|
|||||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
from sglang.srt.openai_api.protocol import Tool
|
from sglang.srt.openai_api.protocol import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MistralDetector(BaseFormatDetector):
|
class MistralDetector(BaseFormatDetector):
|
||||||
"""
|
"""
|
||||||
Detector for Mistral models.
|
Detector for Mistral models.
|
||||||
Assumes function call format:
|
Assumes function call format:
|
||||||
[TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
|
[TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -32,21 +35,6 @@ class MistralDetector(BaseFormatDetector):
|
|||||||
"""Check if the text contains a Mistral format tool call."""
|
"""Check if the text contains a Mistral format tool call."""
|
||||||
return self.bot_token in text
|
return self.bot_token in text
|
||||||
|
|
||||||
def _clean_text(self, text: str) -> str:
|
|
||||||
"""
|
|
||||||
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
|
|
||||||
for example,
|
|
||||||
text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
|
|
||||||
return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
|
|
||||||
The key pattern is [TOOL_CALLS] [...]
|
|
||||||
"""
|
|
||||||
# TODO: check if Mistral supports multiple tool calls, currently assume only support one tool call
|
|
||||||
find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
|
|
||||||
if len(find_results) > 0:
|
|
||||||
return find_results[0]
|
|
||||||
else:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||||
"""
|
"""
|
||||||
One-time parsing: Detects and parses tool calls in the provided text.
|
One-time parsing: Detects and parses tool calls in the provided text.
|
||||||
@@ -57,17 +45,74 @@ class MistralDetector(BaseFormatDetector):
|
|||||||
"""
|
"""
|
||||||
idx = text.find(self.bot_token)
|
idx = text.find(self.bot_token)
|
||||||
normal_text = text[:idx].strip() if idx != -1 else text
|
normal_text = text[:idx].strip() if idx != -1 else text
|
||||||
text = self._clean_text(text)
|
|
||||||
tool_content = text.replace("[TOOL_CALLS]", "").strip()
|
if self.bot_token not in text:
|
||||||
raw_tool_calls = self.tool_call_regex.findall(tool_content)
|
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||||
|
|
||||||
|
# Extract the JSON array part from [TOOL_CALLS] [...]
|
||||||
|
# Use bracket counting to properly handle nested brackets in JSON content
|
||||||
|
json_array_str = self._extract_json_array(text)
|
||||||
|
if not json_array_str:
|
||||||
|
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||||
|
|
||||||
calls = []
|
calls = []
|
||||||
if len(raw_tool_calls) > 0:
|
try:
|
||||||
raw_tool_call = raw_tool_calls[0]
|
function_call_arr = json.loads(json_array_str)
|
||||||
function_call_arr = json.loads(raw_tool_call)
|
# Handle both single object and array of objects
|
||||||
for match_result in function_call_arr:
|
if not isinstance(function_call_arr, list):
|
||||||
calls.extend(self.parse_base_json(match_result, tools))
|
function_call_arr = [function_call_arr]
|
||||||
|
calls = self.parse_base_json(function_call_arr, tools)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to parse JSON part: {json_array_str}, JSON parse error: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||||
|
|
||||||
|
def _extract_json_array(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
Extract the JSON array part using bracket counting to handle nested brackets.
|
||||||
|
|
||||||
|
:param text: The complete text containing [TOOL_CALLS] [...]
|
||||||
|
:return: The JSON array string or None if not found
|
||||||
|
"""
|
||||||
|
start_idx = text.find(self.bot_token)
|
||||||
|
if start_idx == -1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Start from the opening bracket after [TOOL_CALLS]
|
||||||
|
json_start = (
|
||||||
|
start_idx + len(self.bot_token) - 1
|
||||||
|
) # -1 to include the opening bracket
|
||||||
|
bracket_count = 0
|
||||||
|
in_string = False
|
||||||
|
escape_next = False
|
||||||
|
|
||||||
|
for i in range(json_start, len(text)):
|
||||||
|
char = text[i]
|
||||||
|
|
||||||
|
if escape_next:
|
||||||
|
escape_next = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
if char == "\\":
|
||||||
|
escape_next = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if char == '"' and not escape_next:
|
||||||
|
in_string = not in_string
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not in_string:
|
||||||
|
if char == "[":
|
||||||
|
bracket_count += 1
|
||||||
|
elif char == "]":
|
||||||
|
bracket_count -= 1
|
||||||
|
if bracket_count == 0:
|
||||||
|
return text[json_start : i + 1]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def structure_info(self) -> _GetInfoFunc:
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
return lambda name: StructureInfo(
|
return lambda name: StructureInfo(
|
||||||
begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
|
begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
|
||||||
@@ -78,7 +123,8 @@ class MistralDetector(BaseFormatDetector):
|
|||||||
def build_ebnf(self, tools: List[Tool]):
|
def build_ebnf(self, tools: List[Tool]):
|
||||||
return EBNFComposer.build_ebnf(
|
return EBNFComposer.build_ebnf(
|
||||||
tools,
|
tools,
|
||||||
bot_token=self.bot_token,
|
sequence_start_token=self.bot_token,
|
||||||
eot_token=self.eot_token,
|
sequence_end_token=self.eot_token,
|
||||||
function_format="json",
|
function_format="json",
|
||||||
|
tool_call_separator=", ",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -156,8 +156,8 @@ class PythonicDetector(BaseFormatDetector):
|
|||||||
def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
|
def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
|
||||||
return EBNFComposer.build_ebnf(
|
return EBNFComposer.build_ebnf(
|
||||||
tools,
|
tools,
|
||||||
bot_token="[",
|
sequence_start_token="[",
|
||||||
eot_token="]",
|
sequence_end_token="]",
|
||||||
tool_call_separator=",",
|
tool_call_separator=",",
|
||||||
function_format="pythonic",
|
function_format="pythonic",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
|
|||||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
from sglang.srt.openai_api.protocol import Tool
|
from sglang.srt.openai_api.protocol import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Qwen25Detector(BaseFormatDetector):
|
class Qwen25Detector(BaseFormatDetector):
|
||||||
"""
|
"""
|
||||||
Detector for Qwen 2.5 models.
|
Detector for Qwen 2.5 models.
|
||||||
Assumes function call format:
|
Assumes function call format:
|
||||||
<tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
|
<tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -24,8 +27,9 @@ class Qwen25Detector(BaseFormatDetector):
|
|||||||
Initializes the detector with necessary state variables.
|
Initializes the detector with necessary state variables.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.bot_token = "<tool_call>"
|
self.bot_token = "<tool_call>\n"
|
||||||
self.eot_token = "</tool_call>"
|
self.eot_token = "\n</tool_call>"
|
||||||
|
self._normal_text_buffer = "" # Buffer for handling partial end tokens
|
||||||
|
|
||||||
def has_tool_call(self, text: str) -> bool:
|
def has_tool_call(self, text: str) -> bool:
|
||||||
"""Check if the text contains a Qwen 2.5 format tool call."""
|
"""Check if the text contains a Qwen 2.5 format tool call."""
|
||||||
@@ -43,15 +47,64 @@ class Qwen25Detector(BaseFormatDetector):
|
|||||||
normal_text = text[:idx].strip() if idx != -1 else text
|
normal_text = text[:idx].strip() if idx != -1 else text
|
||||||
if self.bot_token not in text:
|
if self.bot_token not in text:
|
||||||
return StreamingParseResult(normal_text=normal_text, calls=[])
|
return StreamingParseResult(normal_text=normal_text, calls=[])
|
||||||
pattern = rf"{self.bot_token}(.*?){self.eot_token}"
|
|
||||||
|
# Find all <tool_call>\n...\n</tool_call> blocks
|
||||||
|
pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
|
||||||
match_result_list = re.findall(pattern, text, re.DOTALL)
|
match_result_list = re.findall(pattern, text, re.DOTALL)
|
||||||
calls = []
|
calls = []
|
||||||
for match_result in match_result_list:
|
for match_result in match_result_list:
|
||||||
match_result = json.loads(match_result)
|
try:
|
||||||
calls.extend(self.parse_base_json(match_result, tools))
|
parsed_call = json.loads(match_result.strip())
|
||||||
|
calls.extend(self.parse_base_json(parsed_call, tools))
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||||
|
|
||||||
|
def parse_streaming_increment(
    self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
    """
    Streaming incremental parsing for Qwen 2.5 tool calls.

    Delegates to the base class implementation, then post-processes the
    normal text it returns so that an end token streamed piece by piece is
    not emitted as normal text.

    Args:
        new_text: The newly streamed chunk of model output.
        tools: The tools available for this request.

    Returns:
        The base class result, with ``normal_text`` replaced by the text
        that is safe to emit now. A trailing fragment that may be the
        beginning of the end token is held back in
        ``self._normal_text_buffer`` until later chunks resolve it.
    """
    result = super().parse_streaming_increment(new_text, tools)

    # Nothing to post-process when the base parser produced no normal text.
    if not result.normal_text:
        return result

    self._normal_text_buffer += result.normal_text

    # The end token is matched without its leading newline(s). Stripping
    # them (rather than slicing off one character) keeps the tag intact
    # even if eot_token is configured without a leading newline.
    end_token_without_newline = self.eot_token.lstrip("\n")

    if end_token_without_newline in self._normal_text_buffer:
        # A complete end token is buffered: drop it and flush the rest.
        result.normal_text = self._normal_text_buffer.replace(
            end_token_without_newline, ""
        )
        self._normal_text_buffer = ""
        return result

    # Check if the buffer might end with the beginning of an end token.
    partial_match_len = self.ends_with_partial_token(
        self._normal_text_buffer, end_token_without_newline
    )
    if partial_match_len:
        # Keep the potential partial match buffered, emit everything before it.
        result.normal_text = self._normal_text_buffer[:-partial_match_len]
        self._normal_text_buffer = self._normal_text_buffer[-partial_match_len:]
    else:
        # No partial match, emit all buffered text.
        result.normal_text = self._normal_text_buffer
        self._normal_text_buffer = ""

    return result
|
||||||
|
|
||||||
def structure_info(self) -> _GetInfoFunc:
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
|
# TODO: Update the begin and end tokens with '\n' if necessary
|
||||||
return lambda name: StructureInfo(
|
return lambda name: StructureInfo(
|
||||||
begin='<tool_call>{"name":"' + name + '", "arguments":',
|
begin='<tool_call>{"name":"' + name + '", "arguments":',
|
||||||
end="}</tool_call>",
|
end="}</tool_call>",
|
||||||
@@ -61,7 +114,8 @@ class Qwen25Detector(BaseFormatDetector):
|
|||||||
def build_ebnf(self, tools: List[Tool]):
|
def build_ebnf(self, tools: List[Tool]):
|
||||||
return EBNFComposer.build_ebnf(
|
return EBNFComposer.build_ebnf(
|
||||||
tools,
|
tools,
|
||||||
bot_token=self.bot_token,
|
individual_call_start_token=self.bot_token.replace("\n", "\\n"),
|
||||||
eot_token=self.eot_token,
|
individual_call_end_token=self.eot_token.replace("\n", "\\n"),
|
||||||
|
tool_call_separator="\\n",
|
||||||
function_format="json",
|
function_format="json",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -265,6 +265,118 @@ class TestPythonicDetector(unittest.TestCase):
|
|||||||
self.assertEqual(params["data"], [1, 2, 3])
|
self.assertEqual(params["data"], [1, 2, 3])
|
||||||
|
|
||||||
|
|
||||||
|
class TestMistralDetector(unittest.TestCase):
    """Tests for MistralDetector.detect_and_parse on complete model outputs."""

    def setUp(self):
        """Build the single-tool list and a fresh detector used by every test."""
        decision_schema = {
            "type": "object",
            "properties": {
                "decision": {
                    "type": "string",
                    "description": "The next step to take",
                },
                "content": {
                    "type": "string",
                    "description": "The content of the next step",
                },
            },
            "required": ["decision", "content"],
        }
        decision_function = Function(
            name="make_next_step_decision",
            description="Test function for decision making",
            parameters=decision_schema,
        )
        self.tools = [Tool(type="function", function=decision_function)]
        self.detector = MistralDetector()

    def _parse(self, model_output):
        """Run one-shot parsing of ``model_output`` against the test tools."""
        return self.detector.detect_and_parse(model_output, self.tools)

    def test_detect_and_parse_with_nested_brackets_in_content(self):
        """Square brackets inside argument strings must not truncate the JSON.

        Covers the case where a regex based extraction cut the payload short
        at the first closing bracket, e.g. in "[City Name]".
        """
        # The exact text that triggered the original failure.
        model_output = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"","content":"```\\nTOOL: Access a weather API or service\\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\\n```"}}]'

        parsed = self._parse(model_output)

        # Exactly one call, for the expected function.
        self.assertEqual(len(parsed.calls), 1, "Should detect exactly one tool call")
        first_call = parsed.calls[0]
        self.assertEqual(
            first_call.name,
            "make_next_step_decision",
            "Should detect the correct function name",
        )

        # The parameters must be valid JSON carrying the untruncated content.
        arguments = json.loads(first_call.parameters)
        self.assertEqual(
            arguments["decision"], "", "Decision parameter should be empty string"
        )
        expected_content = "```\nTOOL: Access a weather API or service\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\n```"
        self.assertEqual(
            arguments["content"],
            expected_content,
            "Content should include nested brackets without truncation",
        )

        # The whole input is a tool call, so no normal text is left over.
        self.assertEqual(
            parsed.normal_text, "", "Normal text should be empty for pure tool call"
        )

    def test_detect_and_parse_simple_case(self):
        """A plain tool call without nested brackets is parsed correctly."""
        model_output = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"TOOL", "content":"Use weather API"}}]'

        parsed = self._parse(model_output)

        self.assertEqual(len(parsed.calls), 1)
        first_call = parsed.calls[0]
        self.assertEqual(first_call.name, "make_next_step_decision")

        arguments = json.loads(first_call.parameters)
        self.assertEqual(arguments["decision"], "TOOL")
        self.assertEqual(arguments["content"], "Use weather API")

    def test_detect_and_parse_no_tool_calls(self):
        """Text without the tool call marker is returned unchanged as normal text."""
        model_output = "This is just normal text without any tool calls."

        parsed = self._parse(model_output)

        self.assertEqual(len(parsed.calls), 0, "Should detect no tool calls")
        self.assertEqual(
            parsed.normal_text,
            model_output,
            "Should return the original text as normal text",
        )

    def test_detect_and_parse_with_text_before_tool_call(self):
        """Text preceding the tool call is reported separately as normal text."""
        model_output = 'Here is some text before the tool call: [TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"ANSWER", "content":"The answer is 42"}}]'

        parsed = self._parse(model_output)

        self.assertEqual(len(parsed.calls), 1)
        self.assertEqual(parsed.normal_text, "Here is some text before the tool call:")

        first_call = parsed.calls[0]
        self.assertEqual(first_call.name, "make_next_step_decision")

        arguments = json.loads(first_call.parameters)
        self.assertEqual(arguments["decision"], "ANSWER")
        self.assertEqual(arguments["content"], "The answer is 42")
|
||||||
|
|
||||||
|
|
||||||
class TestEBNFGeneration(unittest.TestCase):
|
class TestEBNFGeneration(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Create sample tools for testing
|
# Create sample tools for testing
|
||||||
|
|||||||
Reference in New Issue
Block a user