# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501, SIM102

import json
from collections.abc import Sequence
from typing import Any

import regex as re

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)
from vllm.tool_parsers.utils import consume_space
from vllm.utils import random_uuid

logger = init_logger(__name__)

# Reasoning region emitted by the model; tool-call markup found strictly
# inside such a region is ignored by ``preprocess_model_output``.
_THINK_REGION_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)


class HunyuanA13BToolParser(ToolParser):
    """Tool-call parser for Hunyuan A13B model output.

    Tool calls are expected as a JSON array of ``{"name": ..., "arguments":
    ...}`` objects wrapped in ``<tool_calls>...</tool_calls>``. The class
    supports both one-shot extraction (``extract_tool_calls``) and
    incremental extraction (``extract_tool_calls_streaming``).
    """

    def __init__(self, tokenizer: TokenizerLike):
        """Set up streaming state and the regexes used to locate tool calls."""
        super().__init__(tokenizer)

        # Initialize state for streaming mode
        self.prev_tool_calls: list[dict] = []
        self.current_tool_id = -1
        self.current_tool_name_sent = False
        self.streamed_args: list[str] = []  # Track arguments sent for each tool

        # For backward compatibility with tests
        self.current_tools_sent: list[bool] = []

        # For backward compatibility with serving code
        self.prev_tool_call_arr = []

        # Regex patterns for preprocessing.
        # The delimiters are required: without them the lazy group can only
        # ever match the empty string and no tool call is ever found.
        self.answer_tool_calls_pattern = re.compile(
            r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL
        )

        self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"')

        self.tool_empty_arg_reg = re.compile(
            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}'
        )

        # TODO: not support nested json object in fc arguments.
        self.tool_non_empty_arg_reg = re.compile(
            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
        )

        # Marker that opens the tool-call section in streamed text.
        self.bot_string = "<tool_calls>"

        # Per-request streaming bookkeeping:
        #   current_tool_index: index of the tool currently being streamed
        #   tool_ids:           generated call id per tool (None until sent)
        #   sent_tools:         per-tool dict with "sent_name",
        #                       "sent_arguments_prefix" and "sent_arguments"
        self.streaming_state: dict[str, Any] = {
            "current_tool_index": -1,
            "tool_ids": [],
            "sent_tools": [],
        }

    def preprocess_model_output(
        self, model_output: str
    ) -> tuple[str | None, str | None]:
        """Split model output into leading content and the tool-call JSON text.

        Returns ``(content, tool_calls_json)`` for the first
        ``<tool_calls>`` section that is not inside a ``<think>`` region and
        whose payload parses as JSON. If none qualifies, returns
        ``(model_output, None)``.
        """
        # Computed once, and only if at least one tool-call section exists.
        think_regions: list[tuple[int, int]] | None = None

        # find the location tool call
        for match in self.answer_tool_calls_pattern.finditer(model_output):
            start, end = match.span()

            # check tool_calls whether in side of <think>
            if think_regions is None:
                think_regions = [
                    m.span() for m in _THINK_REGION_PATTERN.finditer(model_output)
                ]
            in_think = any(
                start > t_start and end < t_end for t_start, t_end in think_regions
            )
            if in_think:
                continue

            tool_calls_content = match.group(1).strip()
            try:
                # Validation only; the caller parses the text again.
                json.loads(tool_calls_content)
            except (json.JSONDecodeError, RecursionError):
                continue
            return model_output[:start], tool_calls_content

        return model_output, None

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """
        Extract tool calls from a complete model output.

        Entries that are not dicts or lack "name"/"arguments" are skipped.
        On any unexpected error the raw ``model_output`` is returned as
        content with no tool calls.
        """
        try:
            # Preprocess the model output
            content, potential_tool_calls = self.preprocess_model_output(model_output)

            if not potential_tool_calls:
                # some text should be filtered out for no function call
                # this text is in a13b's chat template.
                if content:
                    content = content.replace("助手:", "", 1)
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=content
                )

            # Parse the potential tool calls as JSON
            tool_calls_data = json.loads(potential_tool_calls)

            # Ensure it's an array
            if not isinstance(tool_calls_data, list):
                logger.debug("Tool calls data is not an array")
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=content or model_output,
                )

            tool_calls: list[ToolCall] = []
            for call in tool_calls_data:
                if (
                    not isinstance(call, dict)
                    or "name" not in call
                    or "arguments" not in call
                ):
                    continue

                arguments = call["arguments"]
                if not isinstance(arguments, str):
                    # Serialize dicts and any other JSON value (list, number,
                    # null) so one odd entry does not discard every tool call.
                    arguments = json.dumps(arguments)

                tool_calls.append(
                    ToolCall(
                        id=f"call_{random_uuid()}",
                        type="function",
                        function=FunctionCall(
                            name=call["name"],
                            arguments=arguments,
                        ),
                    )
                )

            if not content or not content.strip():
                # clear the whitespace content.
                content = None

            return ExtractedToolCallInformation(
                tools_called=len(tool_calls) > 0,
                tool_calls=tool_calls,
                content=content,
            )

        except Exception:
            # Best-effort fallback, but leave a trace instead of failing silently.
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """
        Extract tool calls for streaming mode.

        Text that does not start (after optional whitespace and the
        ``bot_string`` marker) with ``[`` is passed through as content.
        Otherwise at most one delta is returned per call: a tool name, or a
        piece of the current tool's arguments, or ``None`` if nothing new can
        be emitted yet.
        """
        start_idx = consume_space(0, current_text)
        if current_text[start_idx:].startswith(self.bot_string):
            start_idx = consume_space(start_idx + len(self.bot_string), current_text)
        if (
            not current_text
            or start_idx >= len(current_text)
            or current_text[start_idx] != "["
        ):
            return DeltaMessage(content=delta_text)

        self._try_parse_json_tools(current_text[start_idx:])

        test_delta = self._handle_test_compatibility(current_text)
        if test_delta:
            return test_delta

        name_matches = list(self.tool_name_reg.finditer(current_text))
        tool_count = len(name_matches)
        if tool_count == 0:
            return None

        self._ensure_state_arrays(tool_count)
        current_idx = self.streaming_state["current_tool_index"]

        name_delta = self._handle_tool_name_streaming(
            current_idx, tool_count, name_matches
        )
        if name_delta:
            return name_delta

        args_delta = self._handle_tool_args_streaming(
            current_text, current_idx, tool_count
        )
        if args_delta:
            return args_delta

        return None

    def _try_parse_json_tools(self, current_text: str) -> None:
        """Record ``current_text`` in ``prev_tool_call_arr`` if it is a JSON list."""
        try:
            parsed_tools = json.loads(current_text)
            if isinstance(parsed_tools, list):
                self.prev_tool_call_arr = parsed_tools
        except json.JSONDecodeError:
            # Incomplete JSON is the normal case while streaming.
            pass

    def _handle_test_compatibility(self, current_text: str) -> DeltaMessage | None:
        """Emit the first tool name when ``current_tools_sent == [False]``.

        This path only triggers when ``current_tools_sent`` has been preset
        to a single ``False`` entry; otherwise it returns ``None``.
        """
        if len(self.current_tools_sent) != 1 or self.current_tools_sent[0] is not False:
            return None

        name_match = self.tool_name_reg.search(current_text)
        if not name_match:
            return None

        function_name = name_match.group(1)
        tool_id = f"chatcmpl-tool-{random_uuid()}"
        delta = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=0,
                    type="function",
                    id=tool_id,
                    function=DeltaFunctionCall(name=function_name).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

        # Mirror the emitted name in every piece of state the regular
        # streaming path consults.
        self.current_tools_sent = [True]
        self.current_tool_id = 0
        self.streaming_state["current_tool_index"] = 0
        if len(self.streaming_state["sent_tools"]) == 0:
            self.streaming_state["sent_tools"].append(
                {
                    "sent_name": True,
                    "sent_arguments_prefix": False,
                    "sent_arguments": "",
                }
            )
        else:
            self.streaming_state["sent_tools"][0]["sent_name"] = True
        self.current_tool_name_sent = True
        return delta

    def _ensure_state_arrays(self, tool_count: int) -> None:
        """Grow ``sent_tools`` and ``tool_ids`` to at least ``tool_count`` entries."""
        while len(self.streaming_state["sent_tools"]) < tool_count:
            self.streaming_state["sent_tools"].append(
                {
                    "sent_name": False,
                    "sent_arguments_prefix": False,
                    "sent_arguments": "",
                }
            )

        while len(self.streaming_state["tool_ids"]) < tool_count:
            self.streaming_state["tool_ids"].append(None)

    def _ensure_streamed_args_slot(self, index: int) -> None:
        """Pad ``streamed_args`` with empty strings so ``index`` is valid."""
        while len(self.streamed_args) <= index:
            self.streamed_args.append("")

    def _advance_to_next_tool(self, current_idx: int, tool_count: int) -> None:
        """Move the streaming cursor forward unless ``current_idx`` is the last tool."""
        if current_idx < tool_count - 1:
            self.streaming_state["current_tool_index"] += 1
            self.current_tool_id = self.streaming_state["current_tool_index"]

    @staticmethod
    def _arguments_delta(index: int, arguments: str) -> DeltaMessage:
        """Build a delta carrying an arguments fragment for tool ``index``."""
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=index,
                    function=DeltaFunctionCall(arguments=arguments).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

    def _handle_tool_name_streaming(
        self, current_idx: int, tool_count: int, name_matches: list
    ) -> DeltaMessage | None:
        """Emit the next tool's name (with a fresh call id) if not yet sent.

        ``name_matches`` holds the ``tool_name_reg`` matches for the current
        text, one per detected tool, in order.
        """
        if current_idx == -1 or current_idx < tool_count - 1:
            next_idx = current_idx + 1
            if (
                next_idx < tool_count
                and not self.streaming_state["sent_tools"][next_idx]["sent_name"]
            ):
                self.streaming_state["current_tool_index"] = next_idx
                self.current_tool_id = next_idx
                current_idx = next_idx
                tool_name = name_matches[current_idx].group(1)

                tool_id = f"call_{current_idx}_{random_uuid()}"
                self.streaming_state["tool_ids"][current_idx] = tool_id

                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=current_idx,
                            type="function",
                            id=tool_id,
                            function=DeltaFunctionCall(name=tool_name).model_dump(
                                exclude_none=True
                            ),
                        )
                    ]
                )
                self.streaming_state["sent_tools"][current_idx]["sent_name"] = True
                self.current_tool_name_sent = True
                self._ensure_streamed_args_slot(current_idx)

                return delta
        return None

    def _handle_tool_args_streaming(
        self, current_text: str, current_idx: int, tool_count: int
    ) -> DeltaMessage | None:
        """Emit the next unsent fragment of the current tool's arguments.

        Empty arguments are sent as a single ``"{}"``. Otherwise the opening
        ``"{"`` is sent first and the remainder once the arguments object is
        matched by ``tool_non_empty_arg_reg``. Returns ``None`` when there is
        nothing new to send.
        """
        if not 0 <= current_idx < tool_count:
            return None

        # Anchor the empty-arguments check at the *current* tool's name. An
        # unanchored search would pick up an earlier tool's `{}` and wrongly
        # report empty arguments for this tool.
        name_matches = list(self.tool_name_reg.finditer(current_text))
        empty_args_match = None
        if current_idx < len(name_matches):
            empty_args_match = self.tool_empty_arg_reg.match(
                current_text, name_matches[current_idx].start()
            )

        if empty_args_match and empty_args_match.start() > 0:
            tool_state = self.streaming_state["sent_tools"][current_idx]
            if not tool_state["sent_arguments_prefix"]:
                tool_state["sent_arguments_prefix"] = True
                tool_state["sent_arguments"] = "{}"

                self._ensure_streamed_args_slot(current_idx)
                self.streamed_args[current_idx] += "{}"

                delta = self._arguments_delta(current_idx, "{}")
                self._advance_to_next_tool(current_idx, tool_count)
                return delta

        args_matches = list(self.tool_non_empty_arg_reg.finditer(current_text))
        if current_idx >= len(args_matches):
            return None

        args_text = args_matches[current_idx].group(1)
        is_last_tool = current_idx == tool_count - 1

        if not is_last_tool:
            next_tool_pos = current_text.find(
                "},{", args_matches[current_idx].start()
            )
            if next_tool_pos != -1:
                # NOTE(review): the slice below keeps the "}" found at
                # next_tool_pos, which looks like the enclosing tool object's
                # closing brace rather than part of the arguments. Confirm
                # whether that is intended before relying on this branch.
                args_end_pos = next_tool_pos + 1
                args_text = (
                    current_text[args_matches[current_idx].start() : args_end_pos]
                    .split('"arguments":')[1]
                    .strip()
                )

        tool_state = self.streaming_state["sent_tools"][current_idx]
        sent_args = tool_state["sent_arguments"]

        # First fragment for this tool: just the opening brace.
        if not tool_state["sent_arguments_prefix"] and args_text.startswith("{"):
            tool_state["sent_arguments_prefix"] = True
            tool_state["sent_arguments"] = "{"

            self._ensure_streamed_args_slot(current_idx)
            self.streamed_args[current_idx] += "{"

            return self._arguments_delta(current_idx, "{")

        # Subsequent fragments: whatever extends what was already sent.
        if args_text.startswith(sent_args):
            args_diff = args_text[len(sent_args) :]
            if args_diff:
                tool_state["sent_arguments"] = args_text

                self._ensure_streamed_args_slot(current_idx)
                self.streamed_args[current_idx] += args_diff

                return self._arguments_delta(current_idx, args_diff)

        # Arguments fully sent: move on to the next tool, if any.
        if args_text.endswith("}") and args_text == sent_args:
            self._advance_to_next_tool(current_idx, tool_count)

        return None