diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py index 497e29c56..d5bb9dc89 100644 --- a/python/sglang/srt/function_call/base_format_detector.py +++ b/python/sglang/srt/function_call/base_format_detector.py @@ -72,20 +72,51 @@ class BaseFormatDetector(ABC): action = json.loads(text) return StreamingParseResult(calls=self.parse_base_json(action, tools)) + def ends_with_partial_token(self, buffer: str, bot_token: str) -> int: + """ + Check if buffer ends with a partial bot_token. + Return the length of the partial bot_token. + + For some format, the bot_token is not a token in model's vocabulary, such as + `[TOOL_CALLS] [` in Mistral. + """ + for i in range(1, min(len(buffer) + 1, len(bot_token))): + if bot_token.startswith(buffer[-i:]): + return i + return 0 + def parse_streaming_increment( self, new_text: str, tools: List[Tool] ) -> StreamingParseResult: """ Streaming incremental parsing with tool validation. + + This base implementation works best with formats where: + 1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array) + 2. JSON can be parsed incrementally using partial_json_loads + 3. Multiple tool calls are separated by "; " or ", " + + Examples of incompatible formats (need custom implementation, may reuse some logic from this class): + - Each tool call is wrapped in a separate block: See Qwen25Detector + - Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...] + - Tool call is Pythonic style + + For incompatible formats, detectors should override this method with custom logic. 
""" # Append new text to buffer self._buffer += new_text current_text = self._buffer if not (self.bot_token in current_text or current_text.startswith("{")): - self._buffer = "" - if self.eot_token in new_text: - new_text = new_text.replace(self.eot_token, "") - return StreamingParseResult(normal_text=new_text) + # Only clear buffer if we're sure no tool call is starting + if not self.ends_with_partial_token(self._buffer, self.bot_token): + normal_text = self._buffer + self._buffer = "" + if self.eot_token in normal_text: + normal_text = normal_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=normal_text) + else: + # Might be partial bot_token, keep buffering + return StreamingParseResult() # Build tool indices if not already built if not hasattr(self, "_tool_indices"): diff --git a/python/sglang/srt/function_call/deepseekv3_detector.py b/python/sglang/srt/function_call/deepseekv3_detector.py index b53c68ab6..cd84eace2 100644 --- a/python/sglang/srt/function_call/deepseekv3_detector.py +++ b/python/sglang/srt/function_call/deepseekv3_detector.py @@ -149,8 +149,8 @@ class DeepSeekV3Detector(BaseFormatDetector): def build_ebnf(self, tools: List[Tool]): return EBNFComposer.build_ebnf( tools, - bot_token=self.bot_token, - eot_token=self.eot_token, + sequence_start_token=self.bot_token, + sequence_end_token=self.eot_token, tool_call_separator="", call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"', function_format="json", diff --git a/python/sglang/srt/function_call/ebnf_composer.py b/python/sglang/srt/function_call/ebnf_composer.py index d749f65d7..8cafd62b1 100644 --- a/python/sglang/srt/function_call/ebnf_composer.py +++ b/python/sglang/srt/function_call/ebnf_composer.py @@ -30,11 +30,6 @@ class EBNFComposer: ws ::= [ \n\t]* """ - TOOL_CALLS_MAP = { - "pythonic": '"[" function_call ("," function_call)* "]"', - "json": "function_call", - } - CALL_RULE_MAP = { "pythonic": 
'call_{name} ::= "{name}" "(" {arguments_rule} ")"', "json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"', @@ -138,35 +133,54 @@ class EBNFComposer: @staticmethod def build_ebnf( tools, - *, - call_rule_fmt: Optional[str] = None, function_format: Literal["pythonic", "json"] = "json", - bot_token: Optional[str] = None, - eot_token: Optional[str] = None, + # Parameters for wrapping the entire sequence of tool calls + sequence_start_token: Optional[str] = None, + sequence_end_token: Optional[str] = None, + # Parameters for wrapping individual tool calls + individual_call_start_token: Optional[str] = None, + individual_call_end_token: Optional[str] = None, + # Parameter for separating multiple tool calls tool_call_separator: Optional[str] = None, + call_rule_fmt: Optional[str] = None, ): """ Generalized EBNF builder for all detectors. Args: tools: List of Tool objects to generate EBNF grammar for + function_format: The format of function calls, either "pythonic" or "json" + sequence_start_token: Token that wraps the entire sequence of tool calls (start) + sequence_end_token: Token that wraps the entire sequence of tool calls (end) + individual_call_start_token: Token that wraps each individual tool call (start) + individual_call_end_token: Token that wraps each individual tool call (end) + tool_call_separator: The separator between multiple tool calls call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default format based on function_format will be used. 
- function_format: The format of function calls, either "pythonic" or "json" - bot_token: The token that indicates the start of a tool call section - eot_token: The token that indicates the end of a tool call section - tool_call_separator: The separator between multiple tool calls """ # ================================================================= # Step 1: Determine the root tool calls rule # ================================================================= - if bot_token and eot_token: - if tool_call_separator: - root_rule = f'"{bot_token}" function_call ( "{tool_call_separator}" function_call )* "{eot_token}"' - else: - root_rule = f'"{bot_token}" function_call "{eot_token}"' + # Handle a single function call + if individual_call_start_token and individual_call_end_token: + function_call_unit = f'"{individual_call_start_token}" function_call "{individual_call_end_token}"' else: - root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format] + function_call_unit = "function_call" + + # Handle multiple function calls with separators + if tool_call_separator is not None: + base_pattern = f'{function_call_unit} ( "{tool_call_separator}" {function_call_unit} )*' + else: + # Assume only support single function call + base_pattern = function_call_unit + + # Apply sequence-level wrapping if needed + if sequence_start_token and sequence_end_token: + root_rule = ( + f'"{sequence_start_token}" {base_pattern} "{sequence_end_token}"' + ) + else: + root_rule = base_pattern # ================================================================= # Step 2: Build the header rules diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py index a5d2475ea..9e3260ffd 100644 --- a/python/sglang/srt/function_call/mistral_detector.py +++ b/python/sglang/srt/function_call/mistral_detector.py @@ -1,4 +1,5 @@ import json +import logging import re from typing import List @@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types 
import ( from sglang.srt.function_call.ebnf_composer import EBNFComposer from sglang.srt.openai_api.protocol import Tool +logger = logging.getLogger(__name__) + class MistralDetector(BaseFormatDetector): """ Detector for Mistral models. Assumes function call format: - [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}] + [TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}] """ def __init__(self): @@ -32,21 +35,6 @@ class MistralDetector(BaseFormatDetector): """Check if the text contains a Mistral format tool call.""" return self.bot_token in text - def _clean_text(self, text: str) -> str: - """ - clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]' - for example, - text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.' - return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]' - The key pattern is [TOOL_CALLS] [...] - """ - # TODO: check if Mistral supports multiple tool calls, currently assume only support one tool call - find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL) - if len(find_results) > 0: - return find_results[0] - else: - return "" - def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: """ One-time parsing: Detects and parses tool calls in the provided text. 
@@ -57,17 +45,74 @@ class MistralDetector(BaseFormatDetector): """ idx = text.find(self.bot_token) normal_text = text[:idx].strip() if idx != -1 else text - text = self._clean_text(text) - tool_content = text.replace("[TOOL_CALLS]", "").strip() - raw_tool_calls = self.tool_call_regex.findall(tool_content) + + if self.bot_token not in text: + return StreamingParseResult(normal_text=normal_text, calls=[]) + + # Extract the JSON array part from [TOOL_CALLS] [...] + # Use bracket counting to properly handle nested brackets in JSON content + json_array_str = self._extract_json_array(text) + if not json_array_str: + return StreamingParseResult(normal_text=normal_text, calls=[]) + calls = [] - if len(raw_tool_calls) > 0: - raw_tool_call = raw_tool_calls[0] - function_call_arr = json.loads(raw_tool_call) - for match_result in function_call_arr: - calls.extend(self.parse_base_json(match_result, tools)) + try: + function_call_arr = json.loads(json_array_str) + # Handle both single object and array of objects + if not isinstance(function_call_arr, list): + function_call_arr = [function_call_arr] + calls = self.parse_base_json(function_call_arr, tools) + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse JSON part: {json_array_str}, JSON parse error: {str(e)}" + ) + return StreamingParseResult(normal_text=normal_text, calls=calls) + def _extract_json_array(self, text: str) -> str: + """ + Extract the JSON array part using bracket counting to handle nested brackets. + + :param text: The complete text containing [TOOL_CALLS] [...] 
+ :return: The JSON array string or None if not found + """ + start_idx = text.find(self.bot_token) + if start_idx == -1: + return None + + # Start from the opening bracket after [TOOL_CALLS] + json_start = ( + start_idx + len(self.bot_token) - 1 + ) # -1 to include the opening bracket + bracket_count = 0 + in_string = False + escape_next = False + + for i in range(json_start, len(text)): + char = text[i] + + if escape_next: + escape_next = False + continue + + if char == "\\": + escape_next = True + continue + + if char == '"' and not escape_next: + in_string = not in_string + continue + + if not in_string: + if char == "[": + bracket_count += 1 + elif char == "]": + bracket_count -= 1 + if bracket_count == 0: + return text[json_start : i + 1] + + return None + def structure_info(self) -> _GetInfoFunc: return lambda name: StructureInfo( begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":', @@ -78,7 +123,8 @@ class MistralDetector(BaseFormatDetector): def build_ebnf(self, tools: List[Tool]): return EBNFComposer.build_ebnf( tools, - bot_token=self.bot_token, - eot_token=self.eot_token, + sequence_start_token=self.bot_token, + sequence_end_token=self.eot_token, function_format="json", + tool_call_separator=", ", ) diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py index e60ab63bf..2ee802284 100644 --- a/python/sglang/srt/function_call/pythonic_detector.py +++ b/python/sglang/srt/function_call/pythonic_detector.py @@ -156,8 +156,8 @@ class PythonicDetector(BaseFormatDetector): def build_ebnf(self, tools: List[Tool]) -> Optional[str]: return EBNFComposer.build_ebnf( tools, - bot_token="[", - eot_token="]", + sequence_start_token="[", + sequence_end_token="]", tool_call_separator=",", function_format="pythonic", ) diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py index 1d32099f7..c43ea384f 100644 --- 
a/python/sglang/srt/function_call/qwen25_detector.py +++ b/python/sglang/srt/function_call/qwen25_detector.py @@ -1,4 +1,5 @@ import json +import logging import re from typing import List @@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import ( from sglang.srt.function_call.ebnf_composer import EBNFComposer from sglang.srt.openai_api.protocol import Tool +logger = logging.getLogger(__name__) + class Qwen25Detector(BaseFormatDetector): """ Detector for Qwen 2.5 models. Assumes function call format: - <tool_call>{"name":"xxx", "arguments":{...}}</tool_call> + <tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call> """ def __init__(self): @@ -24,8 +27,9 @@ class Qwen25Detector(BaseFormatDetector): """ Initializes the detector with necessary state variables. """ super().__init__() - self.bot_token = "<tool_call>" - self.eot_token = "</tool_call>" + self.bot_token = "<tool_call>\n" + self.eot_token = "\n</tool_call>" + self._normal_text_buffer = "" # Buffer for handling partial end tokens def has_tool_call(self, text: str) -> bool: """Check if the text contains a Qwen 2.5 format tool call.""" @@ -43,15 +47,64 @@ class Qwen25Detector(BaseFormatDetector): normal_text = text[:idx].strip() if idx != -1 else text if self.bot_token not in text: return StreamingParseResult(normal_text=normal_text, calls=[]) - pattern = rf"{self.bot_token}(.*?){self.eot_token}" + + # Find all <tool_call>\n...\n</tool_call> blocks + pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}" match_result_list = re.findall(pattern, text, re.DOTALL) calls = [] for match_result in match_result_list: - match_result = json.loads(match_result) - calls.extend(self.parse_base_json(match_result, tools)) + try: + parsed_call = json.loads(match_result.strip()) + calls.extend(self.parse_base_json(parsed_call, tools)) + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}" + ) + continue return StreamingParseResult(normal_text=normal_text, calls=calls) + def 
parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + """ + Streaming incremental parsing for Qwen 2.5 tool calls. + Uses base class implementation with buffering to handle partial end tokens. + """ + result = super().parse_streaming_increment(new_text, tools) + + # Handle partial end tokens that are streamed character by character + if result.normal_text: + self._normal_text_buffer += result.normal_text + + # Check if buffer contains complete end token (without leading newline) + end_token_without_newline = self.eot_token[1:] # "</tool_call>" + if end_token_without_newline in self._normal_text_buffer: + cleaned_text = self._normal_text_buffer.replace( + end_token_without_newline, "" + ) + self._normal_text_buffer = "" + result.normal_text = cleaned_text + else: + # Check if buffer might contain partial end token at the end + partial_match_len = self.ends_with_partial_token( + self._normal_text_buffer, end_token_without_newline + ) + + if partial_match_len: + # Keep potential partial match in buffer, return the rest + result.normal_text = self._normal_text_buffer[:-partial_match_len] + self._normal_text_buffer = self._normal_text_buffer[ + -partial_match_len: + ] + else: + # No partial match, return all buffered text + result.normal_text = self._normal_text_buffer + self._normal_text_buffer = "" + + return result + def structure_info(self) -> _GetInfoFunc: + # TODO: Update the begin and end tokens with '\n' if necessary return lambda name: StructureInfo( begin='<tool_call>{"name":"' + name + '", "arguments":', end="}</tool_call>", @@ -61,7 +114,8 @@ class Qwen25Detector(BaseFormatDetector): def build_ebnf(self, tools: List[Tool]): return EBNFComposer.build_ebnf( tools, - bot_token=self.bot_token, - eot_token=self.eot_token, + individual_call_start_token=self.bot_token.replace("\n", "\\n"), + individual_call_end_token=self.eot_token.replace("\n", "\\n"), + tool_call_separator="\\n", function_format="json", ) diff --git 
a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index 0a00a7dbd..99c7c9dd7 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -265,6 +265,118 @@ class TestPythonicDetector(unittest.TestCase): self.assertEqual(params["data"], [1, 2, 3]) +class TestMistralDetector(unittest.TestCase): + def setUp(self): + """Set up test tools and detector for Mistral format testing.""" + self.tools = [ + Tool( + type="function", + function=Function( + name="make_next_step_decision", + description="Test function for decision making", + parameters={ + "type": "object", + "properties": { + "decision": { + "type": "string", + "description": "The next step to take", + }, + "content": { + "type": "string", + "description": "The content of the next step", + }, + }, + "required": ["decision", "content"], + }, + ), + ), + ] + self.detector = MistralDetector() + + def test_detect_and_parse_with_nested_brackets_in_content(self): + """Test parsing Mistral format with nested brackets in JSON content. + + This test case specifically addresses the issue where the regex pattern + was incorrectly truncating JSON when it contained nested brackets like [City Name]. 
+ """ + # This is the exact problematic text from the original test failure + test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"","content":"```\\nTOOL: Access a weather API or service\\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\\n```"}}]' + + result = self.detector.detect_and_parse(test_text, self.tools) + + # Verify that the parsing was successful + self.assertEqual(len(result.calls), 1, "Should detect exactly one tool call") + + call = result.calls[0] + self.assertEqual( + call.name, + "make_next_step_decision", + "Should detect the correct function name", + ) + + # Verify that the parameters are valid JSON and contain the expected content + params = json.loads(call.parameters) + self.assertEqual( + params["decision"], "", "Decision parameter should be empty string" + ) + + # The content should contain the full text including the nested brackets [City Name] + expected_content = "```\nTOOL: Access a weather API or service\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\n```" + self.assertEqual( + params["content"], + expected_content, + "Content should include nested brackets without truncation", + ) + + # Verify that normal text is empty (since the entire input is a tool call) + self.assertEqual( + result.normal_text, "", "Normal text should be empty for pure tool call" + ) + + def test_detect_and_parse_simple_case(self): + """Test parsing a simple Mistral format tool call without nested brackets.""" + test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"TOOL", "content":"Use weather API"}}]' + + result = self.detector.detect_and_parse(test_text, 
self.tools) + + self.assertEqual(len(result.calls), 1) + call = result.calls[0] + self.assertEqual(call.name, "make_next_step_decision") + + params = json.loads(call.parameters) + self.assertEqual(params["decision"], "TOOL") + self.assertEqual(params["content"], "Use weather API") + + def test_detect_and_parse_no_tool_calls(self): + """Test parsing text without any tool calls.""" + test_text = "This is just normal text without any tool calls." + + result = self.detector.detect_and_parse(test_text, self.tools) + + self.assertEqual(len(result.calls), 0, "Should detect no tool calls") + self.assertEqual( + result.normal_text, + test_text, + "Should return the original text as normal text", + ) + + def test_detect_and_parse_with_text_before_tool_call(self): + """Test parsing text that has content before the tool call.""" + test_text = 'Here is some text before the tool call: [TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"ANSWER", "content":"The answer is 42"}}]' + + result = self.detector.detect_and_parse(test_text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.normal_text, "Here is some text before the tool call:") + + call = result.calls[0] + self.assertEqual(call.name, "make_next_step_decision") + + params = json.loads(call.parameters) + self.assertEqual(params["decision"], "ANSWER") + self.assertEqual(params["content"], "The answer is 42") + + class TestEBNFGeneration(unittest.TestCase): def setUp(self): # Create sample tools for testing