import json
import logging
import re
from typing import List, Optional

from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
    StreamingParseResult,
    StructureInfo,
    _GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer

logger = logging.getLogger(__name__)


class MistralDetector(BaseFormatDetector):
    """
    Detector for Mistral model function call format.

    The Mistral format uses a simple bracket-delimited structure with JSON
    arrays containing function call objects.

    Format Structure:
    ```
    [TOOL_CALLS] [{"name": "function_name", "arguments": {json_args}}, ...]
    ```

    Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        # Marker that opens a tool-call section; it deliberately ends with the
        # "[" that also opens the JSON array of calls.
        self.bot_token = "[TOOL_CALLS] ["
        # Closing bracket of the JSON array.
        self.eot_token = "]"
        # NOTE(review): not referenced by any method in this class; presumably
        # kept for the base class or external callers -- confirm before removing.
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
        # Separator passed to the EBNF composer between consecutive calls.
        self.tool_call_separator = ", "

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Mistral format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools, forwarded to ``parse_base_json``.
        :return: A StreamingParseResult. ``normal_text`` is the whole input when
            no tool-call marker is present, otherwise the stripped text that
            precedes the marker (anything after the JSON array is not
            returned). ``calls`` is empty when no balanced JSON array is found
            or when JSON decoding fails.
        """
        idx = text.find(self.bot_token)
        if idx == -1:
            # No marker: the whole text is plain content, returned unmodified.
            return StreamingParseResult(normal_text=text, calls=[])

        normal_text = text[:idx].strip()

        # Extract the JSON array part from [TOOL_CALLS] [...]
        # Use bracket counting to properly handle nested brackets in JSON content
        json_array_str = self._extract_json_array(text)
        if not json_array_str:
            return StreamingParseResult(normal_text=normal_text, calls=[])

        calls = []
        try:
            function_call_arr = json.loads(json_array_str)
            # Handle both single object and array of objects
            if not isinstance(function_call_arr, list):
                function_call_arr = [function_call_arr]
            calls = self.parse_base_json(function_call_arr, tools)
        except json.JSONDecodeError as e:
            # Best effort: malformed JSON is logged and yields no calls.
            logger.warning(
                f"Failed to parse JSON part: {json_array_str}, JSON parse error: {str(e)}"
            )

        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def _extract_json_array(self, text: str) -> Optional[str]:
        """
        Extract the JSON array part using bracket counting to handle nested brackets.

        Square brackets inside double-quoted strings are ignored, and a
        backslash causes the following character to be skipped so that an
        escaped quote does not toggle the in-string state.

        :param text: The complete text containing [TOOL_CALLS] [...]
        :return: The JSON array string (from the opening "[" through its
            matching "]"), or None if the marker is absent or the array is
            never closed.
        """
        start_idx = text.find(self.bot_token)
        if start_idx == -1:
            return None

        # The default bot_token ends with the array's opening "[", so step
        # back one character to include that bracket in the scan and result.
        json_start = start_idx + len(self.bot_token) - 1

        depth = 0
        in_string = False
        escaped = False

        for pos in range(json_start, len(text)):
            char = text[pos]

            if escaped:
                # Character following a backslash: consume it verbatim.
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = not in_string
            elif not in_string:
                if char == "[":
                    depth += 1
                elif char == "]":
                    depth -= 1
                    if depth == 0:
                        # Matching close of the outermost array.
                        return text[json_start : pos + 1]

        # Ran out of text before the outermost array was closed.
        return None

    def structure_info(self) -> _GetInfoFunc:
        """
        Return a callable that builds the StructureInfo for a given function name.

        The returned callable takes the function name and produces a
        StructureInfo whose ``begin`` is the tool-call prefix up to the
        arguments value, whose ``end`` is ``"}]"`` and whose ``trigger`` is
        ``"[TOOL_CALLS]"``.
        """

        def get_info(name: str) -> StructureInfo:
            return StructureInfo(
                begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
                end="}]",
                trigger="[TOOL_CALLS]",
            )

        return get_info

    def build_ebnf(self, tools: List[Tool]):
        """
        Build the EBNF grammar for this format via ``EBNFComposer.build_ebnf``.

        :param tools: List of available tools.
        :return: Whatever ``EBNFComposer.build_ebnf`` returns for the Mistral
            start/end tokens, JSON function format and call separator.
        """
        return EBNFComposer.build_ebnf(
            tools,
            sequence_start_token=self.bot_token,
            sequence_end_token=self.eot_token,
            function_format="json",
            tool_call_separator=self.tool_call_separator,
        )