# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import Any

import regex as re

from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)
from vllm.tool_parsers.utils import extract_intermediate_diff

logger = init_logger(__name__)


class MinimaxToolParser(ToolParser):
    """Tool parser for MiniMax models.

    MiniMax emits tool calls as newline-separated JSON objects wrapped in
    ``<tool_calls>`` ... ``</tool_calls>`` tags, and reasoning inside
    ``<think>`` ... ``</think>`` tags.  Tool-call tags that appear *inside*
    a thinking block are not real tool calls and are ignored.  Supports both
    one-shot extraction (`extract_tool_calls`) and incremental streaming
    (`extract_tool_calls_streaming`).
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)

        # Streaming state tracking which tool is being emitted and what has
        # already been sent to the client.
        self.streaming_state: dict[str, Any] = {
            "current_tool_index": -1,  # Index of current tool being processed
            "tool_ids": [],  # List of tool call IDs
            "sent_tools": [],  # Per-tool progress: name/argument bytes sent
        }

        # FIX(review): these tag literals were empty strings in the reviewed
        # code, which made every substring test trivially true and broke all
        # tag detection.  Restored to the MiniMax chat-template markers.
        self.tool_call_start_token = "<tool_calls>"
        self.tool_call_end_token = "</tool_calls>"
        # First alternative: a complete, closed tool-call block.
        # Second alternative: an unterminated block (truncated generation).
        self.tool_call_regex = re.compile(
            r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL
        )
        self.thinking_tag_pattern = r"<think>(.*?)</think>"
        self.tool_name_pattern = re.compile(r'"name":\s*"([^"]+)"')
        self.tool_args_pattern = re.compile(r'"arguments":\s*')

        # Buffer for handling partial tool-call tags during streaming.
        self.pending_buffer = ""
        self.in_thinking_tag = False

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Token IDs for the tool-call tags, if the tokenizer has them as
        # single tokens; otherwise fall back to plain string matching.
        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
            logger.warning(
                "Minimax Tool parser could not locate tool call start/end "
                "tokens in the tokenizer. Falling back to string matching."
            )

    def preprocess_model_output(self, model_output: str) -> str:
        """
        Preprocess model output by removing tool calls from thinking tags.

        Tool-call tags inside a ``<think>`` block are part of the model's
        reasoning, not actionable calls, so they are stripped before parsing.

        Args:
            model_output: Raw model output string

        Returns:
            Preprocessed model output with tool calls removed from thinking
            tags
        """

        def remove_tool_calls_from_think(match: re.Match[str]) -> str:
            think_content = match.group(1)
            cleaned_content = re.sub(
                r"<tool_calls>.*?</tool_calls>", "", think_content, flags=re.DOTALL
            )
            return f"<think>{cleaned_content}</think>"

        return re.sub(
            self.thinking_tag_pattern,
            remove_tool_calls_from_think,
            model_output,
            flags=re.DOTALL,
        )

    def _clean_duplicate_braces(self, args_text: str) -> str:
        """
        Clean duplicate closing braces from arguments text.

        Repeatedly drops a trailing ``}`` while doing so turns the text into
        valid JSON; if the text is already valid JSON it is returned as-is.

        Args:
            args_text: Raw arguments text

        Returns:
            Cleaned arguments text with proper JSON formatting
        """
        args_text = args_text.strip()
        if not args_text:
            return args_text

        try:
            json.loads(args_text)
            return args_text
        except json.JSONDecodeError:
            pass

        while args_text.endswith("}}"):
            candidate = args_text[:-1]
            try:
                json.loads(candidate)
                return candidate
            except json.JSONDecodeError:
                args_text = candidate

        return args_text

    def _clean_delta_braces(self, delta_text: str) -> str:
        """
        Clean delta text by removing excessive closing braces.

        If the delta consists solely of braces/whitespace with more than one
        ``}``, collapse it to a single ``}`` (keeping a trailing newline if
        the original had one).

        Args:
            delta_text: Delta text to clean

        Returns:
            Cleaned delta text
        """
        if not delta_text:
            return delta_text

        delta_stripped = delta_text.strip()
        if delta_stripped and all(c in "}\n\r\t " for c in delta_stripped):
            brace_count = delta_stripped.count("}")
            if brace_count > 1:
                return "}\n" if delta_text.endswith("\n") else "}"

        return delta_text

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract tool calls from model output for non-streaming mode.

        Args:
            model_output: Complete model output
            request: Chat completion request

        Returns:
            ExtractedToolCallInformation containing tool calls and content
        """
        processed_output = self.preprocess_model_output(model_output)

        if self.tool_call_start_token not in processed_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            function_call_tuples = self.tool_call_regex.findall(processed_output)

            raw_function_calls = []
            for match in function_call_tuples:
                # match is a 2-tuple: group 1 is a closed block, group 2 an
                # unterminated one; exactly one of them is non-empty.
                tool_call_content = match[0] if match[0] else match[1]
                if tool_call_content.strip():
                    # Each tool call is a single JSON object on its own line.
                    lines = tool_call_content.strip().split("\n")
                    for line in lines:
                        line = line.strip()
                        if line and line.startswith("{") and line.endswith("}"):
                            try:
                                parsed_call = json.loads(line)
                                raw_function_calls.append(parsed_call)
                            except json.JSONDecodeError:
                                continue

            tool_calls = []
            for function_call in raw_function_calls:
                if "name" in function_call and "arguments" in function_call:
                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=function_call["name"],
                                arguments=json.dumps(
                                    function_call["arguments"], ensure_ascii=False
                                ),
                            ),
                        )
                    )

            # Recover the content that precedes the first tool-call block.
            # Map positions back into the *original* output, since
            # preprocessing may have shifted offsets.
            processed_pos = processed_output.find(self.tool_call_start_token)
            if processed_pos != -1:
                processed_content = processed_output[:processed_pos].strip()
                if processed_content:
                    lines = processed_content.split("\n")
                    for line in reversed(lines):
                        line = line.strip()
                        if line:
                            pos = model_output.find(line)
                            if pos != -1:
                                content = model_output[: pos + len(line)]
                                break
                    else:
                        content = ""
                else:
                    content = ""
            else:
                content = model_output

            return ExtractedToolCallInformation(
                tools_called=len(tool_calls) > 0,
                tool_calls=tool_calls,
                content=content.strip() if content.strip() else None,
            )

        except Exception:
            logger.exception(
                "An unexpected error occurred during tool call extraction."
            )
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def _update_thinking_state(self, text: str) -> None:
        """
        Update the thinking tag state based on text content.

        Args:
            text: Text to analyze for thinking tags
        """
        open_count = text.count("<think>")
        close_count = text.count("</think>")
        # Still "thinking" while a tag is unclosed, or when the text ends
        # exactly at the closing tag (so the delta carrying the tag itself is
        # passed through as content).  NOTE(review): the endswith literal was
        # stripped in the reviewed code; "</think>" is the only reachable
        # choice here, since a trailing "<think>" would make open > close.
        self.in_thinking_tag = open_count > close_count or (
            open_count == close_count and text.endswith("</think>")
        )

    def _is_potential_tag_start(self, text: str) -> bool:
        """
        Check if text might be the start of a tool call tag.

        Args:
            text: Text to check

        Returns:
            True if text could be the start of a tool call tag
        """
        # True if any proper suffix of `text` is a prefix of a tag, i.e. the
        # tag may be arriving split across deltas.
        for tag in [self.tool_call_start_token, self.tool_call_end_token]:
            if any(
                tag.startswith(text[-i:])
                for i in range(1, min(len(text) + 1, len(tag)))
            ):
                return True
        return False

    def _should_buffer_content(self, delta_text: str) -> bool:
        """
        Determine if content should be buffered for later processing.

        Args:
            delta_text: Delta text to check

        Returns:
            True if content should be buffered
        """
        if self.in_thinking_tag:
            return False
        return bool(
            self.pending_buffer
            or self.tool_call_start_token in delta_text
            or self.tool_call_end_token in delta_text
            or delta_text.startswith("<")
        )

    def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]:
        """
        Split delta text into safe content and potential tag content.

        Args:
            delta_text: Delta text to split

        Returns:
            Tuple of (safe_content, potential_tag_content)
        """
        if self.in_thinking_tag:
            return delta_text, ""

        # Find the right-most position where a tag prefix begins and is not
        # yet complete; everything before it is safe to emit.
        for tag in [self.tool_call_start_token, self.tool_call_end_token]:
            for i in range(1, len(tag)):
                tag_prefix = tag[:i]
                pos = delta_text.rfind(tag_prefix)
                if pos != -1 and tag.startswith(delta_text[pos:]):
                    return delta_text[:pos], delta_text[pos:]

        return delta_text, ""

    def _process_buffer(self, new_content: str) -> str:
        """
        Process buffered content and return output content.

        Strips complete tool-call tags out of the buffer and returns the
        surrounding text; retains an incomplete tag prefix for the next delta.

        Args:
            new_content: New content to add to buffer

        Returns:
            Processed output content
        """
        self.pending_buffer += new_content
        output_content = ""

        if self.in_thinking_tag:
            output_content = self.pending_buffer
            self.pending_buffer = ""
            return output_content

        while self.pending_buffer:
            start_pos = self.pending_buffer.find(self.tool_call_start_token)
            end_pos = self.pending_buffer.find(self.tool_call_end_token)

            if start_pos != -1 and (end_pos == -1 or start_pos < end_pos):
                tag_pos, tag_len = start_pos, len(self.tool_call_start_token)
            elif end_pos != -1:
                tag_pos, tag_len = end_pos, len(self.tool_call_end_token)
            else:
                # No complete tag; keep a possible partial tag for later.
                if self._is_potential_tag_start(self.pending_buffer):
                    break
                output_content += self.pending_buffer
                self.pending_buffer = ""
                break

            output_content += self.pending_buffer[:tag_pos]
            self.pending_buffer = self.pending_buffer[tag_pos + tag_len :]

        return output_content

    def _reset_streaming_state(self) -> None:
        """Reset the streaming state to initial values."""
        self.streaming_state = {
            "current_tool_index": -1,
            "tool_ids": [],
            "sent_tools": [],
        }

    def _advance_to_next_tool(self) -> None:
        """Advance to the next tool in the streaming sequence."""
        self.streaming_state["current_tool_index"] = (
            int(self.streaming_state["current_tool_index"]) + 1
        )

    def _set_current_tool_index(self, index: int) -> None:
        """
        Set the current tool index.

        Args:
            index: Tool index to set
        """
        self.streaming_state["current_tool_index"] = index

    def _get_current_tool_index(self) -> int:
        """
        Get the current tool index.

        Returns:
            Current tool index
        """
        return int(self.streaming_state["current_tool_index"])

    def _get_next_unsent_tool_index(self, tool_count: int) -> int:
        """
        Get the index of the next unsent tool.

        Args:
            tool_count: Total number of tools

        Returns:
            Index of next unsent tool, or -1 if all tools sent
        """
        sent_tools = list(self.streaming_state["sent_tools"])
        for i in range(tool_count):
            if i < len(sent_tools):
                if not sent_tools[i]["sent_name"]:
                    return i
            else:
                # No state recorded for this index yet, so it is unsent.
                return i
        return -1

    def _ensure_state_arrays(self, tool_count: int) -> None:
        """
        Ensure state arrays have sufficient capacity for tool_count tools.

        Args:
            tool_count: Number of tools to prepare for
        """
        sent_tools = list(self.streaming_state["sent_tools"])
        tool_ids = list(self.streaming_state["tool_ids"])

        while len(sent_tools) < tool_count:
            sent_tools.append(
                {
                    "sent_name": False,
                    "sent_arguments": "",
                    "id": make_tool_call_id(),
                }
            )
        while len(tool_ids) < tool_count:
            tool_ids.append(None)

        self.streaming_state["sent_tools"] = sent_tools
        self.streaming_state["tool_ids"] = tool_ids

    def _detect_tools_in_text(self, text: str) -> int:
        """
        Detect the number of tools in text by counting name patterns.

        Args:
            text: Text to analyze

        Returns:
            Number of tools detected
        """
        matches = self.tool_name_pattern.findall(text)
        return len(matches)

    def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]:
        """
        Find the boundaries of tool calls in text.

        Scans for balanced ``{...}`` spans containing both ``"name"`` and
        ``"arguments"``; an unterminated trailing object with a ``"name"`` is
        also reported (partial tool call still being generated).

        Args:
            text: Text to analyze

        Returns:
            List of (start, end) positions for tool calls
        """
        boundaries = []
        i = 0
        while i < len(text):
            if text[i] == "{":
                start = i
                depth = 0
                has_name = False
                has_arguments = False
                while i < len(text):
                    if text[i] == "{":
                        depth += 1
                    elif text[i] == "}":
                        depth -= 1
                        if depth == 0:
                            end = i + 1
                            segment = text[start:end]
                            if '"name"' in segment and '"arguments"' in segment:
                                boundaries.append((start, end))
                            break
                    if not has_name and '"name"' in text[start : i + 1]:
                        has_name = True
                    if not has_arguments and '"arguments"' in text[start : i + 1]:
                        has_arguments = True
                    i += 1
                if depth > 0 and has_name:
                    # Unclosed object at end of text: a partial tool call.
                    boundaries.append((start, i))
            else:
                i += 1
        return boundaries

    def _extract_tool_args(self, tool_content: str, args_match: re.Match[str]) -> str:
        """
        Extract tool arguments from tool content.

        Args:
            tool_content: Tool call content
            args_match: Regex match for arguments pattern

        Returns:
            Extracted arguments as string
        """
        args_start_pos = args_match.end()
        remaining_content = tool_content[args_start_pos:]

        if remaining_content.strip().startswith("{"):
            # Arguments are a JSON object: take the balanced-brace span.
            depth = 0
            for i, char in enumerate(remaining_content):
                if char == "{":
                    depth += 1
                elif char == "}":
                    depth -= 1
                    if depth == 0:
                        return remaining_content[: i + 1]
        else:
            # Non-object arguments: stop at the enclosing tool-call brace.
            args_end = remaining_content.find("}")
            if args_end > 0:
                return remaining_content[:args_end].strip()

        return remaining_content.rstrip("}").strip()

    def _get_current_tool_content(
        self, text: str, tool_index: int
    ) -> tuple[str | None, str | None]:
        """
        Get the content of a specific tool by index.

        Args:
            text: Text containing tool calls
            tool_index: Index of tool to extract

        Returns:
            Tuple of (tool_name, tool_arguments) or (None, None) if not found
        """
        boundaries = self._find_tool_boundaries(text)
        if tool_index >= len(boundaries):
            return None, None

        start, end = boundaries[tool_index]
        tool_content = text[start:end]

        name_match = self.tool_name_pattern.search(tool_content)
        name = name_match.group(1) if name_match else None

        args_match = self.tool_args_pattern.search(tool_content)
        if args_match:
            try:
                args_text = self._extract_tool_args(tool_content, args_match)
                return name, args_text
            except Exception:
                # Best-effort fallback: everything after "arguments": minus
                # trailing braces.
                remaining_content = tool_content[args_match.end() :]
                args_text = remaining_content.rstrip("}").strip()
                return name, args_text

        return name, None

    def _handle_tool_name_streaming(
        self, tool_content: str, tool_count: int
    ) -> DeltaMessage | None:
        """
        Handle streaming of tool names.

        Args:
            tool_content: Content containing tool calls
            tool_count: Total number of tools

        Returns:
            DeltaMessage with tool name or None if no tool to stream
        """
        next_idx = self._get_next_unsent_tool_index(tool_count)
        if next_idx == -1:
            return None

        boundaries = self._find_tool_boundaries(tool_content)
        if next_idx >= len(boundaries):
            return None

        tool_name, _ = self._get_current_tool_content(tool_content, next_idx)
        if not tool_name:
            return None

        self._set_current_tool_index(next_idx)
        sent_tools = list(self.streaming_state["sent_tools"])
        tool_ids = list(self.streaming_state["tool_ids"])

        tool_id = sent_tools[next_idx]["id"]
        tool_ids[next_idx] = tool_id
        sent_tools[next_idx]["sent_name"] = True

        self.streaming_state["sent_tools"] = sent_tools
        self.streaming_state["tool_ids"] = tool_ids

        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=next_idx,
                    type="function",
                    id=tool_id,
                    function=DeltaFunctionCall(name=tool_name).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

    def _handle_tool_args_streaming(
        self, tool_content: str, tool_count: int
    ) -> DeltaMessage | None:
        """
        Handle streaming of tool arguments.

        Args:
            tool_content: Content containing tool calls
            tool_count: Total number of tools

        Returns:
            DeltaMessage with tool arguments or None if no arguments to stream
        """
        current_idx = self._get_current_tool_index()
        if current_idx < 0 or current_idx >= tool_count:
            return None

        tool_name, tool_args = self._get_current_tool_content(tool_content, current_idx)
        if not tool_name or tool_args is None:
            return None

        sent_tools = list(self.streaming_state["sent_tools"])
        if not sent_tools[current_idx]["sent_name"]:
            # Never emit arguments before the tool's name has been sent.
            return None

        clean_args = self._clean_duplicate_braces(tool_args)
        sent_args = sent_tools[current_idx]["sent_arguments"]

        if clean_args != sent_args:
            if sent_args and clean_args.startswith(sent_args):
                # Incremental growth: send only the new suffix.
                args_delta = extract_intermediate_diff(clean_args, sent_args)
                if args_delta:
                    args_delta = self._clean_delta_braces(args_delta)
                    sent_tools[current_idx]["sent_arguments"] = clean_args
                    self.streaming_state["sent_tools"] = sent_tools
                    if clean_args.endswith("}"):
                        # Arguments object complete; move to the next tool.
                        self._advance_to_next_tool()
                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=current_idx,
                                function=DeltaFunctionCall(
                                    arguments=args_delta
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
            elif not sent_args and clean_args:
                # First chunk of arguments for this tool.
                clean_args_delta = self._clean_delta_braces(clean_args)
                sent_tools[current_idx]["sent_arguments"] = clean_args
                self.streaming_state["sent_tools"] = sent_tools
                if clean_args.endswith("}"):
                    self._advance_to_next_tool()
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=current_idx,
                            function=DeltaFunctionCall(
                                arguments=clean_args_delta
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )

        return None

    def _is_end_tool_calls(self, current_text: str) -> bool:
        """Return True if a tool-call end tag occurs outside any think block."""
        if self.tool_call_end_token not in current_text:
            return False

        end_token_positions = []
        search_start = 0
        while True:
            pos = current_text.find(self.tool_call_end_token, search_start)
            if pos == -1:
                break
            end_token_positions.append(pos)
            search_start = pos + 1

        think_regions = []
        for match in re.finditer(
            self.thinking_tag_pattern, current_text, flags=re.DOTALL
        ):
            think_regions.append((match.start(), match.end()))

        for pos in end_token_positions:
            in_think = any(
                pos >= t_start and pos < t_end for t_start, t_end in think_regions
            )
            if not in_think:
                return True

        return False

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """
        Extract tool calls incrementally as the model streams output.

        Returns a DeltaMessage with plain content, a tool-name delta, or a
        tool-arguments delta — or None when nothing should be emitted yet.
        """
        self._update_thinking_state(current_text)

        # Inside a <think> block everything streams through as content.
        if self.in_thinking_tag:
            return DeltaMessage(content=delta_text)

        if self._should_buffer_content(delta_text):
            buffered_output = self._process_buffer(delta_text)
            return DeltaMessage(content=buffered_output) if buffered_output else None

        # After a closing </tool_calls> (outside thinking), everything that
        # follows is plain content.
        if self._is_end_tool_calls(current_text):
            return DeltaMessage(content=delta_text)

        safe_content, potential_tag = self._split_content_for_buffering(delta_text)
        if potential_tag:
            self.pending_buffer += potential_tag
            return DeltaMessage(content=safe_content) if safe_content else None

        processed_current_text = self.preprocess_model_output(current_text)

        if self.tool_call_start_token not in processed_current_text:
            # The only tool tags so far live inside thinking blocks.
            if (
                self.tool_call_end_token in delta_text
                and self.tool_call_start_token in current_text
            ):
                return None
            if delta_text.strip() == "" and self.tool_call_start_token in current_text:
                return None
            if (
                self._get_current_tool_index() != -1
                and self.tool_call_end_token in current_text
            ):
                self._reset_streaming_state()
            return DeltaMessage(content=delta_text)

        # Suppress the delta that is exactly the start tag token.
        if (
            self.tool_call_start_token_id is not None
            and self.tool_call_start_token_id in delta_token_ids
            and len(delta_token_ids) == 1
        ):
            return None

        original_tool_start = self._find_tool_start_outside_thinking(current_text)
        if original_tool_start is None:
            return None

        content_before_tools = self._extract_content_before_tools(
            current_text, delta_text, original_tool_start
        )
        if content_before_tools:
            return DeltaMessage(content=content_before_tools)

        try:
            # FIX(review): this assignment was split mid-statement in the
            # reviewed text; rejoined.
            tool_content = self._extract_tool_content(
                current_text, original_tool_start
            )
            current_tools_count = self._detect_tools_in_text(tool_content)
            if current_tools_count == 0:
                return None

            if self._get_current_tool_index() == -1:
                self._reset_streaming_state()
            self._ensure_state_arrays(current_tools_count)

            return self._handle_tool_name_streaming(
                tool_content, current_tools_count
            ) or self._handle_tool_args_streaming(tool_content, current_tools_count)
        except Exception:
            # FIX(review): the message was passed as two positional args,
            # which logging treats as a %-format argument; concatenate.
            logger.exception(
                "An unexpected error occurred "
                "during streaming tool call handling."
            )
            return None

    def _find_tool_start_outside_thinking(self, current_text: str) -> int | None:
        """
        Find the start position of tool calls outside of thinking tags.

        Args:
            current_text: Current text to search

        Returns:
            Position of tool call start or None if not found
        """
        # Hoisted out of the loop: the think regions do not change while we
        # scan for candidate start positions.
        think_regions = [
            (m.start(), m.end())
            for m in re.finditer(
                self.thinking_tag_pattern, current_text, flags=re.DOTALL
            )
        ]
        search_start = 0
        while True:
            pos = current_text.find(self.tool_call_start_token, search_start)
            if pos == -1:
                return None
            in_think = any(
                pos >= t_start and pos < t_end for t_start, t_end in think_regions
            )
            if not in_think:
                return pos
            search_start = pos + 1

    def _extract_content_before_tools(
        self, current_text: str, delta_text: str, tool_start: int
    ) -> str | None:
        """
        Extract content that appears before tool calls.

        Args:
            current_text: Current text
            delta_text: Delta text
            tool_start: Start position of tools

        Returns:
            Content before tools or None
        """
        if tool_start > 0:
            delta_start_pos = len(current_text) - len(delta_text)
            if delta_start_pos < tool_start:
                content_part = delta_text
                # Trim any part of the delta that overlaps the tool block.
                if delta_start_pos + len(delta_text) > tool_start:
                    content_part = delta_text[: tool_start - delta_start_pos]
                return content_part if content_part else None
        return None

    def _extract_tool_content(self, current_text: str, tool_start: int) -> str:
        """
        Extract tool content from current text starting at tool_start.

        Args:
            current_text: Current text
            tool_start: Start position of tool calls

        Returns:
            Extracted tool content
        """
        tool_content_start = tool_start + len(self.tool_call_start_token)
        tool_content = current_text[tool_content_start:]

        end_pos = tool_content.find(self.tool_call_end_token)
        if end_pos != -1:
            tool_content = tool_content[:end_pos]

        return tool_content