diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py
index 497e29c56..d5bb9dc89 100644
--- a/python/sglang/srt/function_call/base_format_detector.py
+++ b/python/sglang/srt/function_call/base_format_detector.py
@@ -72,20 +72,51 @@ class BaseFormatDetector(ABC):
action = json.loads(text)
return StreamingParseResult(calls=self.parse_base_json(action, tools))
+ def ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
+ """
+ Check if buffer ends with a partial bot_token.
+ Return the length of the partial match (the shortest suffix of buffer that is a prefix of bot_token), or 0 if there is none.
+
+ For some formats, the bot_token is not a single token in the model's vocabulary, such as
+ `[TOOL_CALLS] [` in Mistral.
+ """
+ for i in range(1, min(len(buffer) + 1, len(bot_token))):
+ if bot_token.startswith(buffer[-i:]):
+ return i
+ return 0
+
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing with tool validation.
+
+ This base implementation works best with formats where:
+ 1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array)
+ 2. JSON can be parsed incrementally using partial_json_loads
+ 3. Multiple tool calls are separated by "; " or ", "
+
+ Examples of incompatible formats (need custom implementation, may reuse some logic from this class):
+ - Each tool call is wrapped in a separate block: See Qwen25Detector
+ - Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...]
+ - Tool call is Pythonic style
+
+ For incompatible formats, detectors should override this method with custom logic.
"""
# Append new text to buffer
self._buffer += new_text
current_text = self._buffer
if not (self.bot_token in current_text or current_text.startswith("{")):
- self._buffer = ""
- if self.eot_token in new_text:
- new_text = new_text.replace(self.eot_token, "")
- return StreamingParseResult(normal_text=new_text)
+ # Only clear buffer if we're sure no tool call is starting
+ if not self.ends_with_partial_token(self._buffer, self.bot_token):
+ normal_text = self._buffer
+ self._buffer = ""
+ if self.eot_token in normal_text:
+ normal_text = normal_text.replace(self.eot_token, "")
+ return StreamingParseResult(normal_text=normal_text)
+ else:
+ # Might be partial bot_token, keep buffering
+ return StreamingParseResult()
# Build tool indices if not already built
if not hasattr(self, "_tool_indices"):
diff --git a/python/sglang/srt/function_call/deepseekv3_detector.py b/python/sglang/srt/function_call/deepseekv3_detector.py
index b53c68ab6..cd84eace2 100644
--- a/python/sglang/srt/function_call/deepseekv3_detector.py
+++ b/python/sglang/srt/function_call/deepseekv3_detector.py
@@ -149,8 +149,8 @@ class DeepSeekV3Detector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf(
tools,
- bot_token=self.bot_token,
- eot_token=self.eot_token,
+ sequence_start_token=self.bot_token,
+ sequence_end_token=self.eot_token,
tool_call_separator="",
call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"',
function_format="json",
diff --git a/python/sglang/srt/function_call/ebnf_composer.py b/python/sglang/srt/function_call/ebnf_composer.py
index d749f65d7..8cafd62b1 100644
--- a/python/sglang/srt/function_call/ebnf_composer.py
+++ b/python/sglang/srt/function_call/ebnf_composer.py
@@ -30,11 +30,6 @@ class EBNFComposer:
ws ::= [ \n\t]*
"""
- TOOL_CALLS_MAP = {
- "pythonic": '"[" function_call ("," function_call)* "]"',
- "json": "function_call",
- }
-
CALL_RULE_MAP = {
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
@@ -138,35 +133,54 @@ class EBNFComposer:
@staticmethod
def build_ebnf(
tools,
- *,
- call_rule_fmt: Optional[str] = None,
function_format: Literal["pythonic", "json"] = "json",
- bot_token: Optional[str] = None,
- eot_token: Optional[str] = None,
+ # Parameters for wrapping the entire sequence of tool calls
+ sequence_start_token: Optional[str] = None,
+ sequence_end_token: Optional[str] = None,
+ # Parameters for wrapping individual tool calls
+ individual_call_start_token: Optional[str] = None,
+ individual_call_end_token: Optional[str] = None,
+ # Parameter for separating multiple tool calls
tool_call_separator: Optional[str] = None,
+ call_rule_fmt: Optional[str] = None,
):
"""
Generalized EBNF builder for all detectors.
Args:
tools: List of Tool objects to generate EBNF grammar for
+ function_format: The format of function calls, either "pythonic" or "json"
+ sequence_start_token: Token that wraps the entire sequence of tool calls (start)
+ sequence_end_token: Token that wraps the entire sequence of tool calls (end)
+ individual_call_start_token: Token that wraps each individual tool call (start)
+ individual_call_end_token: Token that wraps each individual tool call (end)
+ tool_call_separator: The separator between multiple tool calls
call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
format based on function_format will be used.
- function_format: The format of function calls, either "pythonic" or "json"
- bot_token: The token that indicates the start of a tool call section
- eot_token: The token that indicates the end of a tool call section
- tool_call_separator: The separator between multiple tool calls
"""
# =================================================================
# Step 1: Determine the root tool calls rule
# =================================================================
- if bot_token and eot_token:
- if tool_call_separator:
- root_rule = f'"{bot_token}" function_call ( "{tool_call_separator}" function_call )* "{eot_token}"'
- else:
- root_rule = f'"{bot_token}" function_call "{eot_token}"'
+ # Build the grammar unit for one function call, optionally wrapped in per-call tokens
+ if individual_call_start_token and individual_call_end_token:
+ function_call_unit = f'"{individual_call_start_token}" function_call "{individual_call_end_token}"'
else:
- root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format]
+ function_call_unit = "function_call"
+
+ # Handle multiple function calls with separators
+ if tool_call_separator is not None:
+ base_pattern = f'{function_call_unit} ( "{tool_call_separator}" {function_call_unit} )*'
+ else:
+ # No separator provided: assume only a single function call is supported
+ base_pattern = function_call_unit
+
+ # Apply sequence-level wrapping if needed
+ if sequence_start_token and sequence_end_token:
+ root_rule = (
+ f'"{sequence_start_token}" {base_pattern} "{sequence_end_token}"'
+ )
+ else:
+ root_rule = base_pattern
# =================================================================
# Step 2: Build the header rules
diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py
index a5d2475ea..9e3260ffd 100644
--- a/python/sglang/srt/function_call/mistral_detector.py
+++ b/python/sglang/srt/function_call/mistral_detector.py
@@ -1,4 +1,5 @@
import json
+import logging
import re
from typing import List
@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
+logger = logging.getLogger(__name__)
+
class MistralDetector(BaseFormatDetector):
"""
Detector for Mistral models.
Assumes function call format:
- [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
+ [TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}]
"""
def __init__(self):
@@ -32,21 +35,6 @@ class MistralDetector(BaseFormatDetector):
"""Check if the text contains a Mistral format tool call."""
return self.bot_token in text
- def _clean_text(self, text: str) -> str:
- """
- clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
- for example,
- text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
- return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
- The key pattern is [TOOL_CALLS] [...]
- """
- # TODO: check if Mistral supports multiple tool calls, currently assume only support one tool call
- find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
- if len(find_results) > 0:
- return find_results[0]
- else:
- return ""
-
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
@@ -57,17 +45,74 @@ class MistralDetector(BaseFormatDetector):
"""
idx = text.find(self.bot_token)
normal_text = text[:idx].strip() if idx != -1 else text
- text = self._clean_text(text)
- tool_content = text.replace("[TOOL_CALLS]", "").strip()
- raw_tool_calls = self.tool_call_regex.findall(tool_content)
+
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+
+ # Extract the JSON array part from [TOOL_CALLS] [...]
+ # Use bracket counting to properly handle nested brackets in JSON content
+ json_array_str = self._extract_json_array(text)
+ if not json_array_str:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+
calls = []
- if len(raw_tool_calls) > 0:
- raw_tool_call = raw_tool_calls[0]
- function_call_arr = json.loads(raw_tool_call)
- for match_result in function_call_arr:
- calls.extend(self.parse_base_json(match_result, tools))
+ try:
+ function_call_arr = json.loads(json_array_str)
+ # Handle both single object and array of objects
+ if not isinstance(function_call_arr, list):
+ function_call_arr = [function_call_arr]
+ calls = self.parse_base_json(function_call_arr, tools)
+ except json.JSONDecodeError as e:
+ logger.warning(
+ f"Failed to parse JSON part: {json_array_str}, JSON parse error: {str(e)}"
+ )
+
return StreamingParseResult(normal_text=normal_text, calls=calls)
+ def _extract_json_array(self, text: str) -> str:
+ """
+ Extract the JSON array part using bracket counting to handle nested brackets.
+
+ :param text: The complete text containing [TOOL_CALLS] [...]
+ :return: The JSON array string, or None if no complete array is found
+ """
+ start_idx = text.find(self.bot_token)
+ if start_idx == -1:
+ return None
+
+ # Start from the opening bracket after [TOOL_CALLS]
+ json_start = (
+ start_idx + len(self.bot_token) - 1
+ ) # -1 to include the opening bracket
+ bracket_count = 0
+ in_string = False
+ escape_next = False
+
+ for i in range(json_start, len(text)):
+ char = text[i]
+
+ if escape_next:
+ escape_next = False
+ continue
+
+ if char == "\\":
+ escape_next = True
+ continue
+
+ if char == '"' and not escape_next:
+ in_string = not in_string
+ continue
+
+ if not in_string:
+ if char == "[":
+ bracket_count += 1
+ elif char == "]":
+ bracket_count -= 1
+ if bracket_count == 0:
+ return text[json_start : i + 1]
+
+ return None
+
def structure_info(self) -> _GetInfoFunc:
return lambda name: StructureInfo(
begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
@@ -78,7 +123,8 @@ class MistralDetector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf(
tools,
- bot_token=self.bot_token,
- eot_token=self.eot_token,
+ sequence_start_token=self.bot_token,
+ sequence_end_token=self.eot_token,
function_format="json",
+ tool_call_separator=", ",
)
diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py
index e60ab63bf..2ee802284 100644
--- a/python/sglang/srt/function_call/pythonic_detector.py
+++ b/python/sglang/srt/function_call/pythonic_detector.py
@@ -156,8 +156,8 @@ class PythonicDetector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
return EBNFComposer.build_ebnf(
tools,
- bot_token="[",
- eot_token="]",
+ sequence_start_token="[",
+ sequence_end_token="]",
tool_call_separator=",",
function_format="pythonic",
)
diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py
index 1d32099f7..c43ea384f 100644
--- a/python/sglang/srt/function_call/qwen25_detector.py
+++ b/python/sglang/srt/function_call/qwen25_detector.py
@@ -1,4 +1,5 @@
import json
+import logging
import re
from typing import List
@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
+logger = logging.getLogger(__name__)
+
class Qwen25Detector(BaseFormatDetector):
"""
Detector for Qwen 2.5 models.
Assumes function call format:
- {"name":"xxx", "arguments":{...}}
+ \n{"name":"func1", "arguments":{...}}\n\n\n{"name":"func2", "arguments":{...}}\n
"""
def __init__(self):
@@ -24,8 +27,9 @@ class Qwen25Detector(BaseFormatDetector):
Initializes the detector with necessary state variables.
"""
super().__init__()
- self.bot_token = ""
- self.eot_token = ""
+ self.bot_token = "\n"
+ self.eot_token = "\n"
+ self._normal_text_buffer = "" # Buffer for handling partial end tokens
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Qwen 2.5 format tool call."""
@@ -43,15 +47,64 @@ class Qwen25Detector(BaseFormatDetector):
normal_text = text[:idx].strip() if idx != -1 else text
if self.bot_token not in text:
return StreamingParseResult(normal_text=normal_text, calls=[])
- pattern = rf"{self.bot_token}(.*?){self.eot_token}"
+
+ # Find all \n...\n blocks
+ pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
match_result_list = re.findall(pattern, text, re.DOTALL)
calls = []
for match_result in match_result_list:
- match_result = json.loads(match_result)
- calls.extend(self.parse_base_json(match_result, tools))
+ try:
+ parsed_call = json.loads(match_result.strip())
+ calls.extend(self.parse_base_json(parsed_call, tools))
+ except json.JSONDecodeError as e:
+ logger.warning(
+ f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}"
+ )
+ continue
return StreamingParseResult(normal_text=normal_text, calls=calls)
+ def parse_streaming_increment(
+ self, new_text: str, tools: List[Tool]
+ ) -> StreamingParseResult:
+ """
+ Streaming incremental parsing for Qwen 2.5 tool calls.
+ Uses base class implementation with buffering to handle partial end tokens.
+ """
+ result = super().parse_streaming_increment(new_text, tools)
+
+ # Handle partial end tokens that are streamed character by character
+ if result.normal_text:
+ self._normal_text_buffer += result.normal_text
+
+ # Check if buffer contains complete end token (without leading newline)
+ end_token_without_newline = self.eot_token[1:] # ""
+ if end_token_without_newline in self._normal_text_buffer:
+ cleaned_text = self._normal_text_buffer.replace(
+ end_token_without_newline, ""
+ )
+ self._normal_text_buffer = ""
+ result.normal_text = cleaned_text
+ else:
+ # Check if buffer might contain partial end token at the end
+ partial_match_len = self.ends_with_partial_token(
+ self._normal_text_buffer, end_token_without_newline
+ )
+
+ if partial_match_len:
+ # Keep potential partial match in buffer, return the rest
+ result.normal_text = self._normal_text_buffer[:-partial_match_len]
+ self._normal_text_buffer = self._normal_text_buffer[
+ -partial_match_len:
+ ]
+ else:
+ # No partial match, return all buffered text
+ result.normal_text = self._normal_text_buffer
+ self._normal_text_buffer = ""
+
+ return result
+
def structure_info(self) -> _GetInfoFunc:
+ # TODO: Update the begin and end tokens with '\n' if necessary
return lambda name: StructureInfo(
begin='{"name":"' + name + '", "arguments":',
end="}",
@@ -61,7 +114,8 @@ class Qwen25Detector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf(
tools,
- bot_token=self.bot_token,
- eot_token=self.eot_token,
+ individual_call_start_token=self.bot_token.replace("\n", "\\n"),
+ individual_call_end_token=self.eot_token.replace("\n", "\\n"),
+ tool_call_separator="\\n",
function_format="json",
)
diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py
index 0a00a7dbd..99c7c9dd7 100644
--- a/test/srt/test_function_call_parser.py
+++ b/test/srt/test_function_call_parser.py
@@ -265,6 +265,118 @@ class TestPythonicDetector(unittest.TestCase):
self.assertEqual(params["data"], [1, 2, 3])
+class TestMistralDetector(unittest.TestCase):
+ def setUp(self):
+ """Set up test tools and detector for Mistral format testing."""
+ self.tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="make_next_step_decision",
+ description="Test function for decision making",
+ parameters={
+ "type": "object",
+ "properties": {
+ "decision": {
+ "type": "string",
+ "description": "The next step to take",
+ },
+ "content": {
+ "type": "string",
+ "description": "The content of the next step",
+ },
+ },
+ "required": ["decision", "content"],
+ },
+ ),
+ ),
+ ]
+ self.detector = MistralDetector()
+
+ def test_detect_and_parse_with_nested_brackets_in_content(self):
+ """Test parsing Mistral format with nested brackets in JSON content.
+
+ This test case specifically addresses the issue where the regex pattern
+ was incorrectly truncating JSON when it contained nested brackets like [City Name].
+ """
+ # This is the exact problematic text from the original test failure
+ test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"","content":"```\\nTOOL: Access a weather API or service\\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\\n```"}}]'
+
+ result = self.detector.detect_and_parse(test_text, self.tools)
+
+ # Verify that the parsing was successful
+ self.assertEqual(len(result.calls), 1, "Should detect exactly one tool call")
+
+ call = result.calls[0]
+ self.assertEqual(
+ call.name,
+ "make_next_step_decision",
+ "Should detect the correct function name",
+ )
+
+ # Verify that the parameters are valid JSON and contain the expected content
+ params = json.loads(call.parameters)
+ self.assertEqual(
+ params["decision"], "", "Decision parameter should be empty string"
+ )
+
+ # The content should contain the full text including the nested brackets [City Name]
+ expected_content = "```\nTOOL: Access a weather API or service\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\n```"
+ self.assertEqual(
+ params["content"],
+ expected_content,
+ "Content should include nested brackets without truncation",
+ )
+
+ # Verify that normal text is empty (since the entire input is a tool call)
+ self.assertEqual(
+ result.normal_text, "", "Normal text should be empty for pure tool call"
+ )
+
+ def test_detect_and_parse_simple_case(self):
+ """Test parsing a simple Mistral format tool call without nested brackets."""
+ test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"TOOL", "content":"Use weather API"}}]'
+
+ result = self.detector.detect_and_parse(test_text, self.tools)
+
+ self.assertEqual(len(result.calls), 1)
+ call = result.calls[0]
+ self.assertEqual(call.name, "make_next_step_decision")
+
+ params = json.loads(call.parameters)
+ self.assertEqual(params["decision"], "TOOL")
+ self.assertEqual(params["content"], "Use weather API")
+
+ def test_detect_and_parse_no_tool_calls(self):
+ """Test parsing text without any tool calls."""
+ test_text = "This is just normal text without any tool calls."
+
+ result = self.detector.detect_and_parse(test_text, self.tools)
+
+ self.assertEqual(len(result.calls), 0, "Should detect no tool calls")
+ self.assertEqual(
+ result.normal_text,
+ test_text,
+ "Should return the original text as normal text",
+ )
+
+ def test_detect_and_parse_with_text_before_tool_call(self):
+ """Test parsing text that has content before the tool call."""
+ test_text = 'Here is some text before the tool call: [TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"ANSWER", "content":"The answer is 42"}}]'
+
+ result = self.detector.detect_and_parse(test_text, self.tools)
+
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.normal_text, "Here is some text before the tool call:")
+
+ call = result.calls[0]
+ self.assertEqual(call.name, "make_next_step_decision")
+
+ params = json.loads(call.parameters)
+ self.assertEqual(params["decision"], "ANSWER")
+ self.assertEqual(params["content"], "The answer is 42")
+
+
class TestEBNFGeneration(unittest.TestCase):
def setUp(self):
# Create sample tools for testing